Coverage for fss\common\persistence\sqlmodel_impl.py: 35%
109 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-11 19:09 +0800
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-11 19:09 +0800
1"""Sqlmodel impl that do database operations"""
3from typing import Generic, TypeVar, List, Any, Type, Union
5from fastapi_pagination.ext.sqlmodel import paginate
6from pydantic import BaseModel
7from sqlmodel import SQLModel, select, func, insert, update, delete
8from sqlmodel.ext.asyncio.session import AsyncSession
10from fss.common.enum.enum import SortEnum
11from fss.common.persistence.base_mapper import BaseMapper
12from fss.middleware.db_session_middleware import db
# Type variable for the SQLModel table class a mapper instance is bound to.
ModelType = TypeVar("ModelType", bound=SQLModel)

# Pydantic payload type variables. Only SchemaType is referenced by the mapper
# class in this module; the other two are not used in the code visible here.
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
SchemaType = TypeVar("SchemaType", bound=BaseModel)

# Generic SQLModel-bound type variable (not referenced by the mapper class here).
T = TypeVar("T", bound=SQLModel)
class SqlModelMapper(Generic[ModelType], BaseMapper):
    """Generic CRUD mapper for one SQLModel table, driven by an async session.

    Every operation takes an optional ``db_session``; when it is omitted (or
    falsy) the session exposed by the ``db`` middleware object is used.
    No method here commits or flushes -- presumably the session middleware
    owns the transaction boundary (TODO confirm against the middleware).
    """

    def __init__(self, model: Type[ModelType]):
        """Bind the mapper to ``model``, the table class it operates on."""
        self.model = model
        self.db = db

    def get_db_session(self) -> Any:
        """Return the ``db`` middleware object this mapper takes sessions from."""
        return self.db

    def _order_clause(self, order_by: Any, sort_order: Any) -> Any:
        """Build the ORDER BY expression shared by the ordered list queries.

        Falls back to the ``id`` column when ``order_by`` is ``None`` or is
        not a column of the model's table. Any ``sort_order`` other than
        ``SortEnum.ascending`` sorts descending.
        """
        columns = self.model.__table__.columns
        if order_by is None or order_by not in columns:
            order_by = "id"
        column = columns[order_by]
        if sort_order == SortEnum.ascending:
            return column.asc()
        return column.desc()

    async def insert(
        self,
        *,
        data: Union[ModelType, SchemaType],
        db_session: Union[AsyncSession, None] = None,
    ) -> ModelType:
        """Validate ``data`` into the mapped model and add it to the session.

        Returns:
            The ORM instance that was added. It is only added, not flushed,
            so database-generated values are presumably not populated until
            the surrounding transaction flushes or commits.
        """
        db_session = db_session or self.db.session
        orm_data = self.model.model_validate(data)
        db_session.add(orm_data)
        return orm_data

    async def insert_batch(
        self, *, data_list: List[Any], db_session: Any = None
    ) -> int:
        """Insert all items of ``data_list`` with a single INSERT statement.

        Items that are not already instances of the mapped model are
        validated into it first.

        Returns:
            The number of records submitted (0 for an empty list).
        """
        # An INSERT with an empty VALUES list is not a no-op, so skip it.
        if not data_list:
            return 0
        db_session = db_session or self.db.session
        rows = []
        for item in data_list:
            record = (
                item
                if isinstance(item, self.model)
                else self.model.model_validate(item)
            )
            rows.append(record.model_dump())
        await db_session.execute(insert(self.model).values(rows))
        return len(rows)

    async def select_by_id(self, *, id: Any, db_session: Any = None) -> Any:
        """Return the record whose ``id`` column equals ``id``, or ``None``."""
        db_session = db_session or self.db.session
        statement = select(self.model).where(self.model.id == id)
        response = await db_session.execute(statement)
        return response.scalar_one_or_none()

    async def select_by_ids(
        self, *, ids: List[Any], batch_size: int = 1000, db_session: Any = None
    ) -> List[Any]:
        """Return the records whose id is in ``ids``.

        The lookup is split into ``IN`` queries of at most ``batch_size`` ids
        each, and the per-batch results are concatenated in batch order.
        """
        db_session = db_session or self.db.session
        result_set: List[Any] = []
        for start in range(0, len(ids), batch_size):
            batch_ids = ids[start : start + batch_size]
            statement = select(self.model).where(self.model.id.in_(batch_ids))
            # Await the query first, then read the rows: calling .all() on the
            # un-awaited coroutine fails before any query is sent.
            response = await db_session.execute(statement)
            result_set.extend(response.scalars().all())
        return result_set

    async def select_count(self, *, db_session: Any = None) -> int:
        """Return the total number of rows selected from the mapped model."""
        db_session = db_session or self.db.session
        counted = select(self.model).subquery()
        response = await db_session.execute(
            select(func.count()).select_from(counted)
        )
        return response.scalar_one()

    async def select_list(
        self, *, page: int = 1, size: int = 100, query: Any, db_session: Any = None
    ) -> List[Any]:
        """Return one page of records.

        When ``query`` is ``None`` a default query ordered by ``id`` is built
        from ``page`` (1-based) and ``size``. A caller-supplied ``query`` is
        executed as given; ``page`` and ``size`` are not applied to it.
        """
        db_session = db_session or self.db.session
        if query is None:
            query = (
                select(self.model)
                .offset((page - 1) * size)
                .limit(size)
                .order_by(self.model.id)
            )
        response = await db_session.execute(query)
        return response.scalars().all()

    async def select_list_ordered(
        self,
        *,
        page: int = 1,
        size: int = 100,
        query: Any,
        order_by: Any,
        sort_order: Any,
        db_session: Any = None,
    ) -> List[Any]:
        """Return one page of records sorted by a column.

        When ``query`` is ``None`` a default query is built from ``page``
        (1-based), ``size``, ``order_by`` and ``sort_order``; an unknown or
        missing ``order_by`` falls back to ``id``. A caller-supplied
        ``query`` is executed as given, the same way ``select_list`` and
        ``select_list_page_ordered`` treat it.
        """
        db_session = db_session or self.db.session
        if query is None:
            query = (
                select(self.model)
                .offset((page - 1) * size)
                .limit(size)
                .order_by(self._order_clause(order_by, sort_order))
            )
        response = await db_session.execute(query)
        return response.scalars().all()

    async def select_list_page(
        self, *, params: Any, query: Any, db_session: Any = None
    ) -> Any:
        """Paginate ``query`` (default: all records) with fastapi-pagination.

        Returns:
            Whatever ``paginate`` returns for ``params``.
        """
        db_session = db_session or self.db.session
        if query is None:
            query = select(self.model)
        return await paginate(db_session, query, params)

    async def select_list_page_ordered(
        self,
        *,
        params: Any,
        query: Any,
        order_by: Any,
        sort_order: Any,
        db_session: Any = None,
    ) -> Any:
        """Paginate records sorted by a column with fastapi-pagination.

        ``order_by`` / ``sort_order`` are only used to build the default
        query when ``query`` is ``None``; a caller-supplied ``query`` is
        paginated as given.
        """
        db_session = db_session or self.db.session
        if query is None:
            query = select(self.model).order_by(
                self._order_clause(order_by, sort_order)
            )
        return await paginate(db_session, query, params)

    async def update_by_id(self, *, data: Any, db_session: Any = None) -> int:
        """Load the record identified by ``data``'s id and copy fields onto it.

        ``data`` may be a pydantic/SQLModel object (only explicitly set
        fields are copied) or a mapping of attribute names to values.

        Returns:
            0 when no record with that id exists; otherwise the value of
            ``self.count_affected_rows`` for the updated record (that method
            is not defined in this class -- presumably inherited).
        """
        db_session = db_session or self.db.session
        record_id = data["id"] if isinstance(data, dict) else data.id
        if hasattr(data, "model_dump"):
            changes = data.model_dump(exclude_unset=True)
        else:
            changes = dict(data.items())
        result = await db_session.execute(
            select(self.model).where(self.model.id == record_id)
        )
        # execute() always returns a result object, so the missing-row check
        # has to be made on the fetched record, not on the result itself.
        db_data = result.scalar_one_or_none()
        if db_data is None:
            return 0
        for attr, value in changes.items():
            setattr(db_data, attr, value)
        db_session.add(db_data)
        return self.count_affected_rows(db_data)

    async def update_batch_by_ids(
        self, *, data_list: List[Any], db_session: Any = None
    ) -> int:
        """Issue one UPDATE per item of ``data_list`` that has an ``id``.

        Only explicitly set fields of each item are written. Items without
        an ``id`` attribute are skipped.

        Returns:
            The number of UPDATE statements issued (skipped items are not
            counted).
        """
        db_session = db_session or self.db.session
        issued = 0
        for data in data_list:
            if not hasattr(data, "id"):
                continue
            statement = (
                update(self.model)
                .where(self.model.id == data.id)
                .values(**data.model_dump(exclude_unset=True))
            )
            await db_session.execute(statement)
            issued += 1
        return issued

    async def delete_by_id(self, *, id: Any, db_session: Any = None) -> int:
        """Delete the record whose ``id`` column equals ``id``.

        Returns:
            1 once the record has been marked for deletion.

        Raises:
            Whatever ``scalar_one()`` raises when the query does not match
            exactly one row (SQLAlchemy documents ``NoResultFound`` /
            ``MultipleResultsFound``).
        """
        db_session = db_session or self.db.session
        response = await db_session.execute(
            select(self.model).where(self.model.id == id)
        )
        data = response.scalar_one()
        await db_session.delete(data)
        return 1

    async def delete_batch_by_ids(
        self, *, ids: List[Any], db_session: Any = None
    ) -> int:
        """Delete all records whose id is in ``ids`` with one DELETE statement.

        Returns:
            The ``rowcount`` reported for the statement (0 for empty ``ids``,
            in which case no statement is sent).
        """
        if not ids:
            return 0
        db_session = db_session or self.db.session
        statement = delete(self.model).where(self.model.id.in_(ids))
        result = await db_session.execute(statement)
        return result.rowcount