Coverage for fss\middleware\db_session_middleware.py: 82%

68 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-11 19:09 +0800

1"""Session proxy used in the project""" 

2 

from contextvars import ContextVar
from typing import Dict, Optional, Union

from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.types import ASGIApp


try:
    from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
    from sqlalchemy.orm import sessionmaker as async_sessionmaker

18 

19 

def create_middleware_and_session_proxy():
    """Build a middleware class and a session-proxy class that share private state.

    Both returned classes close over the same session factory (``_Session``) and
    the same ``ContextVar`` (``_session``), so every call yields an independent pair.

    Returns:
        tuple: ``(SQLAlchemyMiddleware, DBSession)``. ``SQLAlchemyMiddleware`` is a
        Starlette ``BaseHTTPMiddleware`` that configures the session factory and
        opens a ``DBSession`` around each request. ``DBSession`` is an async context
        manager whose class-level ``session`` property returns the session of the
        current async context.
    """
    # Session factory; assigned when SQLAlchemyMiddleware is instantiated.
    _Session: Optional[async_sessionmaker] = None
    # Usage of context vars inside closures is not recommended, since they are not
    # properly garbage collected, but in our use case the context var is created on
    # program startup and is used throughout the program's whole lifecycle.
    _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)

    class SQLAlchemyMiddleware(BaseHTTPMiddleware):
        """Starlette middleware that runs each request inside a ``DBSession`` context."""

        def __init__(
            self,
            app: ASGIApp,
            db_url: Optional[Union[str, URL]] = None,
            custom_engine: Optional[AsyncEngine] = None,
            engine_args: Optional[Dict] = None,
            session_args: Optional[Dict] = None,
            commit_on_exit: bool = True,
        ):
            """Configure the session factory shared with ``DBSession``.

            Args:
                app: The ASGI application being wrapped.
                db_url: URL passed to ``create_async_engine``; ignored when
                    ``custom_engine`` is given.
                custom_engine: Pre-built async engine to bind sessions to. Takes
                    precedence over ``db_url`` (``engine_args`` is then unused).
                engine_args: Extra keyword arguments for ``create_async_engine``.
                session_args: Extra keyword arguments for the session factory.
                commit_on_exit: Commit the request's session when the request
                    finishes without an exception.

            Raises:
                ValueError: If neither ``db_url`` nor ``custom_engine`` is given.
            """
            super().__init__(app)
            self.commit_on_exit = commit_on_exit
            engine_args = engine_args or {}
            session_args = session_args or {}

            if not custom_engine and not db_url:
                raise ValueError(
                    "You need to pass a db_url or a custom_engine parameter."
                )
            if not custom_engine:
                engine = create_async_engine(db_url, **engine_args)
            else:
                engine = custom_engine

            # Rebinds the factory shared with DBSession: the most recently
            # constructed middleware determines which engine sessions use.
            nonlocal _Session
            _Session = async_sessionmaker(
                engine, class_=AsyncSession, expire_on_commit=False, **session_args
            )

        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
            """Run the downstream app inside a fresh ``DBSession`` context."""
            async with DBSession(commit_on_exit=self.commit_on_exit):
                return await call_next(request)

    class DBSessionMeta(type):
        """Metaclass exposing ``session`` as a read-only property of ``DBSession``."""

        @property
        def session(self) -> AsyncSession:
            """Return an instance of Session local to the current async context.

            Raises:
                SessionNotInitialisedException: If no middleware has configured the
                    session factory yet.
                MissingSessionException: If no ``DBSession`` context is active.
            """
            if _Session is None:
                raise SessionNotInitialisedException

            session = _session.get()
            if session is None:
                raise MissingSessionException

            return session

    class DBSession(metaclass=DBSessionMeta):
        """Async context manager that opens a session for the code it encloses."""

        def __init__(
            self, session_args: Optional[Dict] = None, commit_on_exit: bool = False
        ):
            """Store per-context options.

            Args:
                session_args: Extra keyword arguments for the session factory call.
                commit_on_exit: Commit on a clean exit from the ``async with`` block.
            """
            self.token = None
            self.session_args = session_args or {}
            self.commit_on_exit = commit_on_exit

        async def __aenter__(self):
            """Create a session, make it current, and return the ``DBSession`` class.

            Raises:
                SessionNotInitialisedException: If the session factory is unset.
            """
            if not isinstance(_Session, async_sessionmaker):
                raise SessionNotInitialisedException

            self.token = _session.set(_Session(**self.session_args))  # type: ignore
            return type(self)

        async def __aexit__(self, exc_type, exc_value, traceback):
            """Roll back on error, otherwise optionally commit, then close and restore.

            Returns ``None``, so any exception from the block propagates.
            """
            session = _session.get()

            try:
                if exc_type is not None:
                    await session.rollback()
                elif self.commit_on_exit:
                    # ``elif`` rather than a second ``if``: never commit after a rollback.
                    await session.commit()
            finally:
                try:
                    await session.close()
                finally:
                    # Restore the previous context value even if close() raised, so a
                    # closed session never stays visible to the enclosing context.
                    _session.reset(self.token)

    return SQLAlchemyMiddleware, DBSession

101 

102 

# Module-level pair built once at import time: the Starlette middleware class and
# the ``db`` proxy, whose class-level ``db.session`` returns the current context's
# session. The exception classes defined below are only looked up when the proxy's
# methods run, so defining them after this line is fine.
SQLAlchemyMiddleware, db = create_middleware_and_session_proxy()

104 

105 

class MissingSessionException(Exception):
    """
    Exception raised when a database session is requested but none is active
    in the current context.
    """

    def __init__(self):
        # The example parenthesises the await: ``execute`` returns an awaitable, so
        # ``.fetchall()`` must be called on its awaited result, not on the awaitable.
        detail = """
        No session found! Either you are not currently in a request context,
        or you need to manually create a session context by using a `db` instance as
        a context manager e.g.:

        async with db():
            (await db.session.execute(foo.select())).fetchall()
        """

        super().__init__(detail)

122 

123 

class SessionNotInitialisedException(Exception):
    """
    Exception raised when a DB session is requested before the session factory
    has been configured.
    """

    def __init__(self):
        # Refer to the middleware class this module actually exports
        # (``SQLAlchemyMiddleware``) so the hint is actionable.
        detail = """
        Session not initialised! Ensure that SQLAlchemyMiddleware has been initialised before
        attempting database access.
        """

        super().__init__(detail)