Coverage for src/fastoai/models/_types.py: 75% (81 statements)
coverage.py v7.6.8, created at 2024-12-06 09:34 +0800
from types import UnionType
from typing import (
    Annotated,
    Any,
    Literal,
    Self,
    Union,
    get_args,
    get_origin,
)

import sqlalchemy as sa
from pydantic import BaseModel, RootModel
from sqlalchemy.ext.mutable import Mutable, MutableList
from sqlmodel import JSON, Enum, String
17class BaseModelType(sa.types.TypeDecorator[BaseModel]):
18 """This is a custom SQLAlchemy field that allows easy serialization between database JSONB types and Pydantic models"""
20 impl = JSON
22 def __init__(
23 self,
24 pydantic_model_class: type[BaseModel],
25 *args,
26 **kwargs,
27 ):
28 super().__init__(*args, **kwargs)
29 self.pydantic_model_class = pydantic_model_class
31 def process_bind_param(self, value: BaseModel | list[BaseModel] | None, _): # type: ignore
32 """Convert python native type to JSON string before storing in the database"""
33 match value:
34 case None:
35 return None
36 case list():
37 return [v.model_dump(mode="json") for v in value]
38 case BaseModel():
39 return value.model_dump(mode="json")
40 case _:
41 return value
43 def process_result_value(self, value: Any, _): # type: ignore
44 """Convert JSON string back to Python object after retrieving from the database"""
45 v = None if value is None else self.pydantic_model_class.model_validate(value)
46 return v.root if isinstance(v, RootModel) else v
49class UnionModelType(sa.types.TypeDecorator[RootModel]):
50 """This is a custom SQLAlchemy field that allows easy serialization between database JSONB types and Pydantic models"""
52 impl = String
54 def __init__(
55 self,
56 pydantic_model_class: type[RootModel],
57 *args,
58 **kwargs,
59 ):
60 super().__init__(*args, **kwargs)
61 self.pydantic_model_class = pydantic_model_class
63 def process_bind_param(self, value: RootModel | list[BaseModel] | None, _): # type: ignore
64 """Convert python native type to JSON string before storing in the database"""
65 match value:
66 case None:
67 return None
68 case list():
69 return [v.model_dump(mode="json") for v in value]
70 case BaseModel():
71 return value.model_dump(mode="json")
72 case _:
73 return value
75 def process_result_value(self, value: Any, _): # type: ignore
76 """Convert JSON string back to Python object after retrieving from the database"""
77 return (
78 None
79 if value is None
80 else self.pydantic_model_class.model_validate(value).root
81 )
class MutableBaseModel(Mutable, BaseModel):
    """Mixin that lets SQLAlchemy track in-place changes to a pydantic model.

    ``as_sa_type`` builds a subclass of (model, MutableBaseModel). Setting an
    attribute on an instance of that subclass calls ``Mutable.changed()``, so
    SQLAlchemy is told the JSON column value was modified.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Set the attribute, then notify SQLAlchemy via ``self.changed()``."""
        super().__setattr__(name, value)
        self.changed()

    @classmethod
    def coerce(cls, key: str, value: Any) -> Self | None:
        """Convert an incoming value into an instance of ``cls``.

        Args:
            key: Name of the mapped attribute. It is only used in the error
                raised by the base implementation.
            value: Candidate value.

        Returns:
            ``value`` unchanged when it is ``None`` or already a ``cls``.
            Otherwise, a new ``cls`` instance built from a JSON string, a dict,
            or another pydantic model.

        Raises:
            ValueError: From ``Mutable.coerce`` for unsupported value types.
                pydantic's ``ValidationError`` is also a ``ValueError``.
        """
        if isinstance(value, cls) or value is None:
            return value

        if isinstance(value, str):
            # NOTE(review): strings are parsed as JSON text. For RootModel-based
            # union types whose root may itself be a bare string, this likely
            # rejects such a value. Confirm the intended behaviour.
            return cls.model_validate_json(value)

        if isinstance(value, dict):
            return cls.model_validate(value)

        if isinstance(value, BaseModel):
            # E.g. an instance of the plain (non-mutable) model class. That is
            # what BaseModelType.process_result_value produces, since it
            # validates against the original model, and what user code would
            # normally assign. Rebuild it as the change-tracking subclass
            # instead of rejecting it.
            return cls.model_validate(value.model_dump(by_alias=True))

        return super().coerce(key, value)
def _is_subclass_of_base_model(t: type) -> bool:
    """Tell whether ``t`` is a class deriving from pydantic's ``BaseModel``.

    Inputs that are not classes (typing constructs such as ``Literal[...]`` or
    ``Union[...]``) make ``issubclass`` raise ``TypeError``. They are reported
    as ``False`` instead.
    """
    try:
        outcome = issubclass(t, BaseModel)
    except TypeError:
        outcome = False
    return outcome
def _is_base_model(t: type) -> bool:
    """Return True if ``t`` is, or wraps, a pydantic model annotation.

    Recognised forms:

    * a ``BaseModel`` subclass;
    * a union (``Union[...]`` or ``X | Y``) with at least one ``BaseModel``
      member;
    * ``Annotated[X, ...]`` where ``X`` itself passes this check (recursively).
    """
    if _is_subclass_of_base_model(t):
        return True

    origin = get_origin(t)
    if origin is Union or origin is UnionType:
        # Union members may be non-classes (e.g. ``Literal[...]``), which a
        # bare ``issubclass`` rejects with TypeError. Use the tolerant helper.
        return any(_is_subclass_of_base_model(a) for a in get_args(t))

    return origin is Annotated and _is_base_model(get_args(t)[0])
def as_sa_type(type_: type) -> sa.types.TypeEngine[Any]:
    """Map a Python / pydantic annotation to a SQLAlchemy column type.

    Supported annotations:

    * a pydantic ``BaseModel`` subclass: a ``BaseModelType`` (JSON) column.
      Its values use a generated subclass of the model mixed with
      ``MutableBaseModel``, so attribute assignments are change-tracked;
    * ``Annotated[X, ...]``: the column type for ``X`` (metadata is ignored);
    * ``list[X]`` where ``X`` is a model, a union containing one, or an
      annotated model: a ``BaseModelType`` column over ``RootModel[list[X]]``,
      tracked with ``MutableList``;
    * ``Union[...]`` or ``X | Y``: a ``UnionModelType`` column over a
      ``RootModel`` of the union, again mixed with ``MutableBaseModel``;
    * ``Literal[...]``: an ``Enum`` of the literal values.

    Raises:
        ValueError: For any other annotation.
    """
    # Only the subclass check may legitimately raise TypeError (for non-class
    # annotations). Errors from building the mutable subclass must surface
    # instead of being swallowed.
    if _is_subclass_of_base_model(type_):
        mutable_model = type(type_.__name__, (type_, MutableBaseModel), {})
        return mutable_model.as_mutable(BaseModelType(type_))  # type: ignore

    origin = get_origin(type_)
    args = get_args(type_)
    # Non-generic annotations (e.g. ``int``) have no args. Avoid an IndexError
    # so they reach the ValueError below.
    first_arg = args[0] if args else None

    if origin is Annotated:
        return as_sa_type(first_arg)  # type: ignore

    if origin is list and first_arg is not None and _is_base_model(first_arg):
        return MutableList[first_arg].as_mutable(BaseModelType(RootModel[type_]))  # type: ignore

    if origin is Union or origin is UnionType:
        root_model = RootModel[type_]
        mutable_root = type(root_model.__name__, (root_model, MutableBaseModel), {})
        return mutable_root.as_mutable(UnionModelType(root_model))  # type: ignore

    if origin is Literal:
        return Enum(*args)  # type: ignore

    raise ValueError(f"Unsupported type {type_}")