Coverage for src/fastoai/routers/beta/threads.py: 57%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-12-06 09:34 +0800

1from fastapi import APIRouter 

2from openai.types.beta.thread_create_params import ThreadCreateParams 

3from openai.types.beta.threads.message_create_params import MessageCreateParams 

4from pydantic import RootModel 

5 

6from ...dependencies import SessionDependency 

7from ...models import Thread 

8from .messages import create_message 

9 

10router = APIRouter() 

11 

12 

@router.post("/threads")
async def create_thread(
    *,
    params: RootModel[ThreadCreateParams] | None = None,
    session: SessionDependency,
) -> Thread:
    """Persist a new thread and create any messages supplied with it.

    Args:
        params: Optional thread-creation payload. When omitted, the thread
            is validated from an empty mapping.
        session: Database session used to add, commit and refresh the thread.

    Returns:
        The stored thread, refreshed after all initial messages were created.
    """
    # "messages" is kept out of the thread payload: each entry is passed to
    # ``create_message`` once the thread itself has been committed.
    if params is None:
        thread_fields = {}
    else:
        thread_fields = params.model_dump(exclude={"messages"}, exclude_none=True)
    thread = Thread.model_validate(thread_fields)

    session.add(thread)
    await session.commit()

    if params is not None:
        initial_messages = params.root.get("messages", [])
    else:
        initial_messages = []

    for raw_message in initial_messages:
        await create_message(
            thread_id=thread.id,
            params=RootModel[MessageCreateParams].model_validate(raw_message),
            session=session,
        )

    await session.refresh(thread)
    return thread

34 

35 

@router.get("/threads/{thread_id}")
async def retrieve_thread(*, thread_id: str, session: SessionDependency) -> Thread:
    """Return the thread identified by ``thread_id``.

    Args:
        thread_id: Identifier taken from the request path.
        session: Database session used for the lookup.

    Returns:
        The thread obtained from ``session.get_one(Thread, thread_id)``.
    """
    return await session.get_one(Thread, thread_id)

39 return thread