Coverage for fss\common\persistence\sqlmodel_impl.py: 35%

109 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-11 19:09 +0800

1"""SQLModel implementation that performs database operations"""

2 

3from typing import Generic, TypeVar, List, Any, Type, Union 

4 

5from fastapi_pagination.ext.sqlmodel import paginate 

6from pydantic import BaseModel 

7from sqlmodel import SQLModel, select, func, insert, update, delete 

8from sqlmodel.ext.asyncio.session import AsyncSession 

9 

10from fss.common.enum.enum import SortEnum 

11from fss.common.persistence.base_mapper import BaseMapper 

12from fss.middleware.db_session_middleware import db 

13 

# Type variables shared by the generic mapper defined in this module.

# Table/ORM model class a mapper instance is bound to.
ModelType = TypeVar("ModelType", bound=SQLModel)

# Pydantic schema type variables for create / update payloads.
# NOTE(review): not referenced in the visible part of this module —
# presumably imported by other modules; confirm before removing.
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)

# Generic pydantic schema accepted as input data (used by ``insert``).
SchemaType = TypeVar("SchemaType", bound=BaseModel)

# Generic SQLModel type variable.
# NOTE(review): unused in the visible part of this module — confirm usage.
T = TypeVar("T", bound=SQLModel)

19 

20 

class SqlModelMapper(Generic[ModelType], BaseMapper):
    """Generic async CRUD mapper bound to a single SQLModel class.

    Every operation takes an optional ``db_session``; when it is omitted (or
    falsy) the session exposed by the module-level ``db`` object is used.
    None of the methods in this class commit the session themselves.
    """

    def __init__(self, model: Type[ModelType]):
        """Bind the mapper to ``model`` and to the module-level ``db`` object."""
        self.model = model
        self.db = db

    def get_db_session(self) -> Any:
        """Return the ``db`` object this mapper reads its default session from."""
        return self.db

    def _order_clause(self, order_by: Any, sort_order: Any) -> Any:
        """Build the ORDER BY expression shared by the ``*_ordered`` queries.

        Falls back to the ``id`` column when ``order_by`` is ``None`` or is
        not a column of the model's table. Sorting is ascending only when
        ``sort_order`` equals ``SortEnum.ascending``; anything else sorts
        descending.
        """
        columns = self.model.__table__.columns
        if order_by is None or order_by not in columns:
            order_by = "id"
        column = columns[order_by]
        if sort_order == SortEnum.ascending:
            return column.asc()
        return column.desc()

    async def insert(
        self,
        *,
        data: Union[ModelType, SchemaType],
        db_session: Union[AsyncSession, None] = None,
    ) -> ModelType:
        """Validate ``data`` into the mapped model and add it to the session.

        Returns:
            The validated model instance that was added to the session
            (the object is only added, not flushed or committed here).
        """
        db_session = db_session or self.db.session
        orm_data = self.model.model_validate(data)
        db_session.add(orm_data)
        return orm_data

    async def insert_batch(
        self, *, data_list: List[Any], db_session: Any = None
    ) -> int:
        """Insert all items of ``data_list`` with one multi-row INSERT.

        Items that are not already instances of the mapped model are
        validated into it first.

        Returns:
            The number of items submitted for insertion (0 for empty input).
        """
        # Nothing to insert: avoid emitting an INSERT built from an empty
        # value list.
        if not data_list:
            return 0
        db_session = db_session or self.db.session
        orm_datas = [
            data if isinstance(data, self.model) else self.model.model_validate(data)
            for data in data_list
        ]
        statement = insert(self.model).values(
            [orm_data.model_dump() for orm_data in orm_datas]
        )
        await db_session.execute(statement)
        return len(orm_datas)

    async def select_by_id(self, *, id: Any, db_session: Any = None) -> Any:
        """Return the row whose ``id`` column equals ``id``, or ``None``."""
        db_session = db_session or self.db.session
        statement = select(self.model).where(self.model.id == id)
        response = await db_session.execute(statement)
        return response.scalar_one_or_none()

    async def select_by_ids(
        self, *, ids: List[Any], batch_size: int = 1000, db_session: Any = None
    ) -> List[Any]:
        """Return all rows whose ``id`` is in ``ids``.

        The lookup is split into chunks of ``batch_size`` ids so that each
        ``IN (...)`` clause stays bounded.
        """
        db_session = db_session or self.db.session
        result_set: List[Any] = []
        for start in range(0, len(ids), batch_size):
            batch_ids = ids[start : start + batch_size]
            statement = select(self.model).where(self.model.id.in_(batch_ids))
            # Await the query first, then read the rows; calling ``.all()``
            # directly on the un-awaited coroutine fails.
            response = await db_session.execute(statement)
            result_set.extend(response.scalars().all())
        return result_set

    async def select_count(self, *, db_session: Any = None) -> int:
        """Return the total number of rows of the mapped model."""
        db_session = db_session or self.db.session
        response = await db_session.execute(
            select(func.count()).select_from(select(self.model).subquery())
        )
        return response.scalar_one()

    async def select_list(
        self,
        *,
        page: int = 1,
        size: int = 100,
        query: Any = None,
        db_session: Any = None,
    ) -> List[Any]:
        """Return one page of rows.

        When ``query`` is ``None`` a default query is built: rows ordered by
        ``id``, skipping ``(page - 1) * size`` rows and returning at most
        ``size``. A supplied ``query`` is executed as given; ``page`` and
        ``size`` are not applied to it.
        """
        db_session = db_session or self.db.session
        if query is None:
            query = (
                select(self.model)
                .offset((page - 1) * size)
                .limit(size)
                .order_by(self.model.id)
            )
        response = await db_session.execute(query)
        return response.scalars().all()

    async def select_list_ordered(
        self,
        *,
        page: int = 1,
        size: int = 100,
        query: Any = None,
        order_by: Any,
        sort_order: Any,
        db_session: Any = None,
    ) -> List[Any]:
        """Return one page of rows sorted by a caller-selected column.

        When ``query`` is ``None`` a default query is built from ``page``,
        ``size``, ``order_by`` and ``sort_order`` (unknown or missing
        ``order_by`` falls back to ``id``). A supplied ``query`` is executed
        as given, matching ``select_list`` and ``select_list_page_ordered``;
        previously it was silently discarded.
        """
        db_session = db_session or self.db.session
        if query is None:
            query = (
                select(self.model)
                .offset((page - 1) * size)
                .limit(size)
                .order_by(self._order_clause(order_by, sort_order))
            )
        response = await db_session.execute(query)
        return response.scalars().all()

    async def select_list_page(
        self, *, params: Any, query: Any = None, db_session: Any = None
    ) -> Any:
        """Paginate ``query`` (default: all rows of the model) via ``paginate``.

        Returns:
            Whatever ``fastapi_pagination``'s ``paginate`` returns for
            ``params``.
        """
        db_session = db_session or self.db.session
        if query is None:
            query = select(self.model)
        return await paginate(db_session, query, params)

    async def select_list_page_ordered(
        self,
        *,
        params: Any,
        query: Any = None,
        order_by: Any,
        sort_order: Any,
        db_session: Any = None,
    ) -> Any:
        """Paginate rows sorted by a caller-selected column via ``paginate``.

        The ordering is only applied when ``query`` is ``None``; a supplied
        ``query`` is passed to ``paginate`` unchanged.
        """
        db_session = db_session or self.db.session
        if query is None:
            query = select(self.model).order_by(
                self._order_clause(order_by, sort_order)
            )
        return await paginate(db_session, query, params)

    async def update_by_id(self, *, data: Any, db_session: Any = None) -> int:
        """Copy the fields of ``data`` onto the stored row with ``data.id``.

        Returns:
            0 when no row with that id exists; otherwise the value of
            ``self.count_affected_rows`` for the updated row.
        """
        db_session = db_session or self.db.session
        query = select(self.model).where(self.model.id == data.id)
        result = await db_session.execute(query)
        # The execute result itself is never None; the row has to be
        # inspected to detect a missing record.
        db_data = result.scalar_one_or_none()
        if db_data is None:
            return 0
        # Pydantic/SQLModel objects expose their fields through model_dump();
        # mapping-like inputs that provide items() are still accepted.
        if hasattr(data, "model_dump"):
            changes = data.model_dump(exclude_unset=True)
        else:
            changes = dict(data.items())
        for attr, value in changes.items():
            setattr(db_data, attr, value)
        db_session.add(db_data)
        # NOTE(review): count_affected_rows is not defined in this class;
        # presumably inherited from BaseMapper — confirm.
        return self.count_affected_rows(db_data)

    async def update_batch_by_ids(
        self, *, data_list: List[Any], db_session: Any = None
    ) -> int:
        """Issue one UPDATE per item of ``data_list`` that has an ``id``.

        Only explicitly set fields of each item are written. Items without
        an ``id`` attribute are skipped.

        Returns:
            The number of UPDATE statements executed (skipped items are not
            counted).
        """
        db_session = db_session or self.db.session
        updated = 0
        for data in data_list:
            if not hasattr(data, "id"):
                continue
            statement = (
                update(self.model)
                .where(self.model.id == data.id)
                .values(**data.model_dump(exclude_unset=True))
            )
            await db_session.execute(statement)
            updated += 1
        return updated

    async def delete_by_id(self, *, id: Any, db_session: Any = None) -> int:
        """Load the row with the given ``id`` and delete it via the session.

        Returns:
            1 once the delete has been issued.

        Raises:
            Whatever ``scalar_one()`` raises when the id does not match
            exactly one row.
        """
        db_session = db_session or self.db.session
        response = await db_session.execute(
            select(self.model).where(self.model.id == id)
        )
        data = response.scalar_one()
        await db_session.delete(data)
        return 1

    async def delete_batch_by_ids(
        self, *, ids: List[Any], db_session: Any = None
    ) -> int:
        """Delete all rows whose ``id`` is in ``ids`` with a single DELETE.

        Returns:
            The ``rowcount`` reported for the DELETE (0 for empty ``ids``).
        """
        # No ids: nothing can match, so skip the statement.
        if not ids:
            return 0
        db_session = db_session or self.db.session
        statement = delete(self.model).where(self.model.id.in_(ids))
        result = await db_session.execute(statement)
        return result.rowcount

192 statement = delete(self.model).where(self.model.id.in_(ids)) 

193 result = await db_session.execute(statement) 

194 return result.rowcount