Coverage for src/fastoai/models/_types.py: 75%

81 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-12-06 09:34 +0800

import json
from typing import (
    Annotated,
    Any,
    Literal,
    Self,
    Union,
    get_args,
    get_origin,
)

import sqlalchemy as sa
from pydantic import BaseModel, RootModel, ValidationError
from sqlalchemy.ext.mutable import Mutable, MutableList
from sqlmodel import JSON, Enum, String

15 

16 

17class BaseModelType(sa.types.TypeDecorator[BaseModel]): 

18 """This is a custom SQLAlchemy field that allows easy serialization between database JSONB types and Pydantic models""" 

19 

20 impl = JSON 

21 

22 def __init__( 

23 self, 

24 pydantic_model_class: type[BaseModel], 

25 *args, 

26 **kwargs, 

27 ): 

28 super().__init__(*args, **kwargs) 

29 self.pydantic_model_class = pydantic_model_class 

30 

31 def process_bind_param(self, value: BaseModel | list[BaseModel] | None, _): # type: ignore 

32 """Convert python native type to JSON string before storing in the database""" 

33 match value: 

34 case None: 

35 return None 

36 case list(): 

37 return [v.model_dump(mode="json") for v in value] 

38 case BaseModel(): 

39 return value.model_dump(mode="json") 

40 case _: 

41 return value 

42 

43 def process_result_value(self, value: Any, _): # type: ignore 

44 """Convert JSON string back to Python object after retrieving from the database""" 

45 v = None if value is None else self.pydantic_model_class.model_validate(value) 

46 return v.root if isinstance(v, RootModel) else v 

47 

48 

49class UnionModelType(sa.types.TypeDecorator[RootModel]): 

50 """This is a custom SQLAlchemy field that allows easy serialization between database JSONB types and Pydantic models""" 

51 

52 impl = String 

53 

54 def __init__( 

55 self, 

56 pydantic_model_class: type[RootModel], 

57 *args, 

58 **kwargs, 

59 ): 

60 super().__init__(*args, **kwargs) 

61 self.pydantic_model_class = pydantic_model_class 

62 

63 def process_bind_param(self, value: RootModel | list[BaseModel] | None, _): # type: ignore 

64 """Convert python native type to JSON string before storing in the database""" 

65 match value: 

66 case None: 

67 return None 

68 case list(): 

69 return [v.model_dump(mode="json") for v in value] 

70 case BaseModel(): 

71 return value.model_dump(mode="json") 

72 case _: 

73 return value 

74 

75 def process_result_value(self, value: Any, _): # type: ignore 

76 """Convert JSON string back to Python object after retrieving from the database""" 

77 return ( 

78 None 

79 if value is None 

80 else self.pydantic_model_class.model_validate(value).root 

81 ) 

82 

83 

class MutableBaseModel(Mutable, BaseModel):
    """Mixin that lets SQLAlchemy's mutation tracking see in-place edits of a pydantic model.

    Every attribute assignment marks the owning SQLAlchemy parent(s) as changed,
    and ``coerce`` turns incoming column values into an instance of the mixed
    class.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Set the attribute, then notify SQLAlchemy that the value changed."""
        super().__setattr__(name, value)
        self.changed()

    @classmethod
    def coerce(cls, key: str, value: Any) -> Self | None:
        """Convert an incoming value into an instance of ``cls``.

        Handling by input type:
        * ``None`` or an existing ``cls`` instance: returned unchanged.
        * ``str``: parsed as JSON. For ``RootModel`` subclasses, a string
          that is not valid JSON for the model is validated directly as the
          root (e.g. a ``Literal`` member of a Union root).
        * ``dict``: validated as model data.
        * Any other pydantic model instance, such as an instance of the un-mixed
          parent model: re-validated from its ``model_dump()``.
        * For ``RootModel`` subclasses, any other raw root value: validated
          directly.

        Raises:
            pydantic.ValidationError: if validation fails (this is a
                ``ValueError`` subclass).
            ValueError: for unsupported value types on non-root models.
        """
        if isinstance(value, cls) or value is None:
            return value

        if isinstance(value, str):
            try:
                return cls.model_validate_json(value)
            except ValidationError:
                if not issubclass(cls, RootModel):
                    raise
                return cls.model_validate(value)

        if isinstance(value, dict):
            return cls.model_validate(value)

        if isinstance(value, BaseModel):
            return cls.model_validate(value.model_dump())

        if issubclass(cls, RootModel):
            return cls.model_validate(value)

        return super().coerce(key, value)

105 

106 

def _is_subclass_of_base_model(t: type) -> bool:
    """Report whether *t* is a class derived from pydantic ``BaseModel``.

    ``issubclass`` raises ``TypeError`` when *t* is not a class, e.g.
    ``Literal[...]`` or a parametrized generic such as ``list[int]``. Such
    inputs are reported as ``False`` instead of propagating the error.
    """
    try:
        result = issubclass(t, BaseModel)
    except TypeError:
        result = False
    return result

112 

113 

def _is_base_model(t: type) -> bool:
    """Return True if the annotation *t* describes pydantic model data.

    Accepted shapes:
    * a ``BaseModel`` subclass;
    * a ``typing.Union`` with at least one ``BaseModel`` subclass member;
    * ``Annotated[X, ...]`` where ``X`` satisfies one of these rules.
    """

    def _is_union_of_base_models(t: type) -> bool:
        # Union members may be non-classes (e.g. Literal["auto"], list[int]);
        # raw issubclass() raises TypeError on those, so use the safe helper.
        return get_origin(t) is Union and any(
            _is_subclass_of_base_model(a) for a in get_args(t)
        )

    return (
        _is_subclass_of_base_model(t)
        or _is_union_of_base_models(t)
        or (get_origin(t) is Annotated and _is_base_model(get_args(t)[0]))
    )

125 

126 

def as_sa_type(type_: type) -> type:
    """Map a Python/pydantic annotation to a SQLAlchemy column type.

    Supported annotations:
    * a ``BaseModel`` subclass: a ``BaseModelType`` (JSON) column whose
      values are tracked through a ``MutableBaseModel`` subclass of the model;
    * ``Annotated[X, ...]``: whatever ``X`` maps to;
    * ``list[X]`` where ``X`` is a model, a Union containing a model, or an
      annotated model: a ``BaseModelType`` column over ``RootModel[list[X]]``,
      tracked as a ``MutableList``;
    * ``typing.Union[...]``: a ``UnionModelType`` (String) column over
      ``RootModel[type_]``, tracked through ``MutableBaseModel``;
    * ``Literal[...]``: an ``Enum`` of the literal values.

    Raises:
        ValueError: for any other annotation, including plain classes such as
            ``str`` that carry no type arguments.
    """
    # Narrow check: only issubclass() is allowed to fail with TypeError here.
    # Errors raised while building the mutable model must not be swallowed.
    if _is_subclass_of_base_model(type_):
        mutable_model = type(type_.__name__, (type_, MutableBaseModel), {})
        return mutable_model.as_mutable(BaseModelType(type_))  # type: ignore

    origin = get_origin(type_)
    args = get_args(type_)

    if origin is Annotated:
        return as_sa_type(args[0])

    # `args` may be empty (e.g. bare typing.List), so guard before indexing.
    if origin is list and args and _is_base_model(args[0]):
        item_type = args[0]
        return MutableList[item_type].as_mutable(  # type: ignore
            BaseModelType(RootModel[type_])
        )

    if origin is Union:
        root_model = RootModel[type_]
        mutable_root = type(root_model.__name__, (root_model, MutableBaseModel), {})
        return mutable_root.as_mutable(UnionModelType(root_model))  # type: ignore

    if origin is Literal:
        return Enum(*args)  # type: ignore

    raise ValueError(f"Unsupported type {type_}")