Coverage for src/fastoai/dependencies.py: 62%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-12-06 09:34 +0800

1from functools import lru_cache 

2from typing import Annotated 

3 

4from fastapi import Depends, HTTPException, status 

5from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer 

6from openai import AsyncOpenAI 

7from sqlalchemy.ext.asyncio import create_async_engine 

8from sqlmodel import SQLModel 

9from sqlmodel.ext.asyncio.session import AsyncSession 

10 

11from .models import APIKey, User 

12from .settings import Settings, get_settings 

13 

# Injects the application ``Settings`` object produced by ``get_settings``.
SettingsDependency = Annotated[Settings, Depends(dependency=get_settings)]

15 

16 

async def get_session(settings: SettingsDependency):
    """Yield an ``AsyncSession`` bound to the configured database.

    A new async engine is built from ``settings.database_url`` for each call.
    The SQLModel metadata tables are created on it (``create_all``), and a
    session is yielded to the caller. When the caller is finished, the session
    is closed and the engine is disposed. Disposal closes the engine's pooled
    connections instead of leaving them open until garbage collection.

    Args:
        settings: Application settings injected by FastAPI.

    Yields:
        An open ``AsyncSession``; it is closed when the dependency exits.
    """
    engine = create_async_engine(settings.database_url)
    try:
        # NOTE(review): building an engine and running ``create_all`` on every
        # request is costly; consider a long-lived engine created at startup.
        async with engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.create_all)
        async with AsyncSession(engine) as session:
            yield session
    finally:
        # An AsyncEngine that merely goes out of scope keeps its pooled
        # connections open; dispose it explicitly so nothing leaks per call.
        await engine.dispose()


SessionDependency = Annotated[AsyncSession, Depends(get_session)]

27 

28 

@lru_cache
def get_openai(settings: SettingsDependency):
    """Return an ``AsyncOpenAI`` client configured from ``settings.openai``.

    The client's constructor keyword arguments are the fields of
    ``settings.openai`` as produced by ``model_dump()``.

    ``functools.lru_cache`` memoizes the result per ``settings`` argument.
    Calls with an equal settings object therefore reuse a single client.
    Because ``lru_cache`` hashes its arguments, ``settings`` must be hashable.
    NOTE(review): confirm ``Settings`` is hashable; a pydantic model is only
    hashable when it is frozen.
    """
    client_options = settings.openai.model_dump()
    return AsyncOpenAI(**client_options)


OpenAIDependency = Annotated[AsyncOpenAI, Depends(dependency=get_openai)]

36 

security = HTTPBearer(auto_error=True)


async def get_user(
    *,
    credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
    session: SessionDependency,
) -> User:
    """Resolve the authenticated ``User`` from a bearer API key.

    The bearer token is used as the primary-key lookup value for ``APIKey``.

    Args:
        credentials: Bearer credentials extracted by ``security``.
        session: Database session used for the lookups.

    Returns:
        The ``User`` that owns the API key, provided it is active.

    Raises:
        HTTPException: 401 if no ``APIKey`` matches the token; 403 if the
            owning user is not active.
    """
    key_row = await session.get(APIKey, credentials.credentials)
    if key_row is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
        )

    # Load the ``user`` relationship explicitly rather than relying on an
    # implicit lazy load, which is not supported under asyncio.
    await session.refresh(key_row, ["user"])
    owner = key_row.user
    if owner.is_active:
        return owner
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
    )


UserDependency = Annotated[User, Depends(dependency=get_user)]