Coverage for src/fastoai/routers/beta/assistants.py: 55%
56 statements
« prev ^ index » next coverage.py v7.6.8, created at 2024-12-06 09:34 +0800
« prev ^ index » next coverage.py v7.6.8, created at 2024-12-06 09:34 +0800
1from typing import Annotated, cast
3from fastapi import APIRouter
4from openai.types.beta.assistant import Assistant as _Assistant
5from openai.types.beta.assistant_create_params import AssistantCreateParams
6from openai.types.beta.assistant_deleted import AssistantDeleted
7from openai.types.beta.assistant_update_params import AssistantUpdateParams
8from pydantic import Field, RootModel
9from sqlmodel import col, select
11from ...dependencies import SessionDependency
12from ...models import Assistant
13from ...pagination import AsyncCursorPage
14from .._types import Order
16router = APIRouter()
@router.post("/assistants", response_model=_Assistant)
async def create_assistant(
    params: RootModel[AssistantCreateParams],
    session: SessionDependency,
) -> _Assistant:
    """Create an assistant from the request body and return it.

    Args:
        params: Request body wrapped in a ``RootModel`` over
            ``AssistantCreateParams``.
        session: Database session the new row is added to and committed on.

    Returns:
        The stored assistant converted with ``to_openai_model()``.
    """
    # Dump the validated body to plain data, then build the table model from it.
    payload = params.model_dump()
    record = Assistant.model_validate(payload)

    session.add(record)
    await session.commit()
    # Reload the row after the commit before serializing it.
    await session.refresh(record)

    return await record.to_openai_model()
@router.get("/assistants", response_model=AsyncCursorPage[_Assistant])
async def list_assistants(
    *,
    limit: Annotated[int, Field(ge=1, le=100)] = 20,
    order: Order = "desc",
    after: str | None = None,
    before: str | None = None,
    session: SessionDependency,
) -> AsyncCursorPage[_Assistant]:
    """Return one page of assistants ordered by ``created_at``.

    Args:
        limit: Maximum number of assistants to return (1 to 100, default 20).
        order: Name of the ordering method applied to ``created_at``
            (default ``"desc"``).
        after: ID of a cursor assistant; only assistants that sort after it
            in the requested order are returned.
        before: ID of a cursor assistant; only assistants that sort before it
            in the requested order are returned.
        session: Database session used for the cursor lookups and the query.

    Returns:
        A page whose ``data`` holds the converted assistants. ``first_id`` and
        ``last_id`` are passed only when the page is not empty.

    Raises:
        Whatever ``session.get_one`` raises when ``after`` or ``before`` does
        not match a stored assistant.
    """
    descending = order == "desc"
    statement = select(Assistant).order_by(getattr(col(Assistant.created_at), order)())

    if after is not None:
        after_assistant = await session.get_one(Assistant, after)
        # "After" the cursor means older rows when descending, newer otherwise.
        statement = statement.where(
            Assistant.created_at < after_assistant.created_at
            if descending
            else Assistant.created_at > after_assistant.created_at
        )
    if before is not None:
        before_assistant = await session.get_one(Assistant, before)
        # Mirror image of the `after` filter.
        statement = statement.where(
            Assistant.created_at > before_assistant.created_at
            if descending
            else Assistant.created_at < before_assistant.created_at
        )

    assistants = list((await session.exec(statement.limit(limit))).all())

    kwargs = {}
    if assistants:
        kwargs["first_id"] = assistants[0].id
        kwargs["last_id"] = assistants[-1].id

    # NOTE(review): the previous version re-fetched the last row and built a
    # follow-up query here whenever the page was full, but never executed it,
    # so it had no effect on the response. It was presumably meant to compute
    # a "has more" flag -- confirm against AsyncCursorPage and implement it
    # there if the page model supports one.
    return AsyncCursorPage[_Assistant](
        data=[await a.to_openai_model() for a in assistants], **kwargs
    )
@router.get("/assistants/{assistant_id}", response_model=_Assistant)
async def retrieve_assistant(
    assistant_id: str,
    session: SessionDependency,
) -> _Assistant:
    """Fetch a single assistant by ID.

    Args:
        assistant_id: ID taken from the request path.
        session: Database session used for the lookup.

    Returns:
        The matching assistant converted with ``to_openai_model()``.

    Raises:
        Whatever ``session.get_one`` raises when no assistant has this ID.
    """
    found = await session.get_one(Assistant, assistant_id)
    converted = await found.to_openai_model()
    return converted
@router.post("/assistants/{assistant_id}", response_model=_Assistant)
async def update_assistant(
    assistant_id: str,
    params: RootModel[AssistantUpdateParams],
    session: SessionDependency,
) -> _Assistant:
    """Apply a partial update to an assistant and return the result.

    Only the fields the client actually sent are written; unset fields are
    left untouched.

    Args:
        assistant_id: ID taken from the request path.
        params: Request body wrapped in a ``RootModel`` over
            ``AssistantUpdateParams``.
        session: Database session used for the lookup and the commit.

    Returns:
        The updated assistant converted with ``to_openai_model()``.

    Raises:
        Whatever ``session.get_one`` raises when no assistant has this ID.
    """
    assistant = await session.get_one(Assistant, assistant_id)
    obj = cast(AssistantUpdateParams, params.model_dump(exclude_unset=True))
    for k, v in obj.items():
        # The request key "metadata" is stored on the model's `metadata_`
        # attribute; every other key maps to the attribute of the same name.
        setattr(assistant, k if k != "metadata" else "metadata_", v)
    await session.commit()
    # Reload the row before serializing, as create_assistant does: a commit
    # expires loaded attributes by default, and an async session cannot
    # lazy-load them implicitly on attribute access.
    await session.refresh(assistant)
    return await assistant.to_openai_model()
@router.delete("/assistants/{assistant_id}", response_model=AssistantDeleted)
async def delete_assistant(
    assistant_id: str,
    session: SessionDependency,
) -> AssistantDeleted:
    """Delete an assistant by ID and report the deletion.

    Args:
        assistant_id: ID taken from the request path.
        session: Database session used for the lookup, delete and commit.

    Returns:
        An ``AssistantDeleted`` object carrying the deleted ID.

    Raises:
        Whatever ``session.get_one`` raises when no assistant has this ID.
    """
    # Look the row up first so a missing ID fails before anything is deleted.
    doomed = await session.get_one(Assistant, assistant_id)
    await session.delete(doomed)
    await session.commit()

    return AssistantDeleted(
        id=assistant_id,
        deleted=True,
        object="assistant.deleted",
    )