Coverage for src/fastoai/routers/beta/assistants.py: 55%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-12-06 09:34 +0800

1from typing import Annotated, cast 

2 

3from fastapi import APIRouter 

4from openai.types.beta.assistant import Assistant as _Assistant 

5from openai.types.beta.assistant_create_params import AssistantCreateParams 

6from openai.types.beta.assistant_deleted import AssistantDeleted 

7from openai.types.beta.assistant_update_params import AssistantUpdateParams 

8from pydantic import Field, RootModel 

9from sqlmodel import col, select 

10 

11from ...dependencies import SessionDependency 

12from ...models import Assistant 

13from ...pagination import AsyncCursorPage 

14from .._types import Order 

15 

# Router on which every /assistants endpoint in this module is registered.
router = APIRouter()

17 

18 

@router.post("/assistants", response_model=_Assistant)
async def create_assistant(
    params: RootModel[AssistantCreateParams],
    session: SessionDependency,
) -> _Assistant:
    """Persist a new assistant built from the request body.

    The validated payload is turned into an ``Assistant`` row, committed,
    reloaded from the database, and serialized back into the OpenAI
    ``Assistant`` schema.
    """
    payload = params.model_dump()
    record = Assistant.model_validate(payload)
    session.add(record)
    await session.commit()
    # Reload the row's attributes from the database after the commit.
    await session.refresh(record)
    return await record.to_openai_model()

29 

30 

@router.get("/assistants", response_model=AsyncCursorPage[_Assistant])
async def list_assistants(
    *,
    limit: Annotated[int, Field(ge=1, le=100)] = 20,
    order: Order = "desc",
    after: str | None = None,
    before: str | None = None,
    session: SessionDependency,
) -> AsyncCursorPage[_Assistant]:
    """List assistants using cursor-based pagination on ``created_at``.

    Args:
        limit: Maximum number of assistants to return (1-100).
        order: Sort direction of the returned page ("asc" or "desc").
        after: ID of an assistant; return the page that follows it.
        before: ID of an assistant; return the page that immediately
            precedes it.
        session: Database session.

    Returns:
        A page whose ``data`` is sorted by ``order``. ``first_id`` and
        ``last_id`` are set when the page is non-empty. ``has_more`` reports
        whether further rows exist past this page in the direction being
        paged.

    Raises:
        Whatever ``session.get_one`` raises when a cursor ID does not exist.
    """
    created_at = col(Assistant.created_at)

    # With only `before`, scan away from the cursor in the opposite direction
    # so the rows nearest the cursor are picked, then flip them back into
    # `order`. Scanning in `order` would return the head of the list instead
    # of the page adjacent to the cursor.
    backwards = before is not None and after is None
    if backwards:
        scan_order = "asc" if order == "desc" else "desc"
    else:
        scan_order = order
    statement = select(Assistant).order_by(getattr(created_at, scan_order)())

    # NOTE(review): the strict comparisons below skip rows that share the
    # cursor's created_at; consider an `id` tiebreaker if timestamps can repeat.
    if after is not None:
        after_assistant = await session.get_one(Assistant, after)
        statement = statement.where(
            created_at < after_assistant.created_at
            if order == "desc"
            else created_at > after_assistant.created_at
        )
    if before is not None:
        before_assistant = await session.get_one(Assistant, before)
        statement = statement.where(
            created_at > before_assistant.created_at
            if order == "desc"
            else created_at < before_assistant.created_at
        )

    # Fetch one row beyond the page size; its presence means another page exists.
    rows = list((await session.exec(statement.limit(limit + 1))).all())
    has_more = len(rows) > limit
    assistants = rows[:limit]
    if backwards:
        assistants.reverse()

    kwargs = {}
    # NOTE(review): assumes AsyncCursorPage accepts `has_more` (OpenAI list
    # responses carry it) - confirm against ...pagination.
    kwargs["has_more"] = has_more
    if assistants:
        kwargs["first_id"] = assistants[0].id
        kwargs["last_id"] = assistants[-1].id
    return AsyncCursorPage[_Assistant](
        data=[await a.to_openai_model() for a in assistants], **kwargs
    )

74 

75 

@router.get("/assistants/{assistant_id}", response_model=_Assistant)
async def retrieve_assistant(
    assistant_id: str,
    session: SessionDependency,
) -> _Assistant:
    """Return the assistant whose primary key is ``assistant_id``.

    ``session.get_one`` raises when no such row exists.
    """
    found = await session.get_one(Assistant, assistant_id)
    openai_model = await found.to_openai_model()
    return openai_model

83 

84 

@router.post("/assistants/{assistant_id}", response_model=_Assistant)
async def update_assistant(
    assistant_id: str,
    params: RootModel[AssistantUpdateParams],
    session: SessionDependency,
) -> _Assistant:
    """Apply a partial update to an existing assistant.

    Only keys explicitly present in the request body are written
    (``exclude_unset=True``). The request's ``metadata`` key is stored on the
    model's ``metadata_`` attribute.

    Raises:
        Whatever ``session.get_one`` raises when ``assistant_id`` is unknown.
    """
    assistant = await session.get_one(Assistant, assistant_id)
    changes = cast(AssistantUpdateParams, params.model_dump(exclude_unset=True))
    for field, value in changes.items():
        # The model exposes metadata as `metadata_` (presumably because
        # `metadata` is reserved by SQLAlchemy's declarative base).
        setattr(assistant, "metadata_" if field == "metadata" else field, value)
    await session.commit()
    # As create_assistant does, reload after commit so serialization does not
    # read attributes the commit may have expired; on an async session that
    # would otherwise require an implicit lazy load.
    await session.refresh(assistant)
    return await assistant.to_openai_model()

97 

98 

@router.delete("/assistants/{assistant_id}", response_model=AssistantDeleted)
async def delete_assistant(
    assistant_id: str,
    session: SessionDependency,
) -> AssistantDeleted:
    """Delete the assistant identified by ``assistant_id``.

    Returns an ``AssistantDeleted`` acknowledgement echoing the requested ID.
    ``session.get_one`` raises when no such row exists.
    """
    doomed = await session.get_one(Assistant, assistant_id)
    await session.delete(doomed)
    await session.commit()
    return AssistantDeleted(
        object="assistant.deleted",
        id=assistant_id,
        deleted=True,
    )