Coverage for src/fastoai/routers/beta/messages.py: 54%

24 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-12-06 09:34 +0800

from collections.abc import Iterable, Mapping
from typing import Any, Optional, Union, cast

from fastapi import APIRouter
from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.threads.message_create_params import MessageCreateParams
from pydantic import RootModel

from ...dependencies import SessionDependency
from ...models import Message, Thread
from ...pagination import AsyncCursorPage

11 

# Router holding this module's thread-message endpoints. It is created with no
# prefix of its own, so each route's path is exactly the one in its decorator
# (before any prefix the application may add when including it).
router = APIRouter()

13 

14 

def _to_content_blocks(
    content: Union[str, Iterable[Mapping[str, Any]]],
) -> list[dict[str, Any]]:
    """Convert request-side message content into stored content blocks.

    A plain string becomes a single text block of the form
    ``{"type": "text", "text": {"value": ..., "annotations": []}}``.
    For an iterable of content parts, text parts given in the request shape
    ``{"type": "text", "text": "..."}`` are rewritten into that same block
    shape, so both input forms are stored identically. Every other part
    (including a text part whose ``text`` is already a mapping) is copied.

    The result is always a concrete list. The input may be a one-shot
    iterator, so it must be consumed exactly once, here.
    """
    if isinstance(content, str):
        return [{"type": "text", "text": {"value": content, "annotations": []}}]
    blocks: list[dict[str, Any]] = []
    # NOTE(review): with pydantic v2, parts of an ``Iterable`` field are
    # validated lazily, so a malformed part may raise while iterating here
    # rather than during request validation — confirm desired error mapping.
    for part in content:
        text = part.get("text")
        if part.get("type") == "text" and isinstance(text, str):
            blocks.append({"type": "text", "text": {"value": text, "annotations": []}})
        else:
            blocks.append(dict(part))
    return blocks


def _to_attachments(
    attachments: Optional[Iterable[Mapping[str, Any]]],
) -> Optional[list[dict[str, Any]]]:
    """Materialize request attachments (and each one's ``tools``) into lists.

    Returns ``None`` when no attachments were supplied. Otherwise each
    attachment is copied into a plain dict, and its ``tools`` iterable (when
    present and not ``None``) is replaced by a list of plain dicts. Pydantic
    may hand these fields over as one-shot iterators, which would not survive
    being stored or serialized.
    """
    if attachments is None:
        return None
    materialized: list[dict[str, Any]] = []
    for attachment in attachments:
        item = dict(attachment)
        if item.get("tools") is not None:
            item["tools"] = [dict(tool) for tool in item["tools"]]
        materialized.append(item)
    return materialized


@router.post("/threads/{thread_id}/messages")
async def create_message(
    thread_id: str,
    params: RootModel[MessageCreateParams],
    session: SessionDependency,
):
    """Create and persist a message on ``thread_id`` from a create-message body.

    ``content`` is normalized with ``_to_content_blocks`` and ``attachments``
    with ``_to_attachments``. The ``Message`` row is built with ``status``
    fixed to ``"completed"`` and the requested ``role``. The row is then added,
    committed and refreshed, and returned.

    NOTE(review): ``metadata`` from the request body is not stored.
    NOTE(review): the ORM row is returned as-is, whereas ``list_messages``
    converts rows with ``to_openai_model()``. Confirm the serialized response
    matches the OpenAI Message schema.
    NOTE(review): the thread's existence is not checked here. Confirm whether
    a missing thread is rejected at commit time (e.g. by a foreign key).
    """
    body = params.root
    message = Message(  # type: ignore
        thread_id=thread_id,
        attachments=_to_attachments(body.get("attachments")),
        status="completed",
        content=_to_content_blocks(body["content"]),
        role=body["role"],
    )
    session.add(message)
    await session.commit()
    # Reload server-side/default column values after the commit.
    await session.refresh(message)
    return message

41 

42 

@router.get("/threads/{thread_id}/messages")
async def list_messages(
    thread_id: str,
    session: SessionDependency,
) -> AsyncCursorPage[OpenAIMessage]:
    """Return every message of a thread as an OpenAI-style cursor page.

    Loads the thread with ``session.get_one``, which raises if no thread
    matches ``thread_id``. It then awaits the thread's ``messages``
    relationship and converts each stored ``Message`` with
    ``to_openai_model()``. No limit, ordering or cursor parameters are
    applied: the page holds all messages in the order the relationship yields
    them.
    """
    owner = await session.get_one(Thread, thread_id)
    stored = cast(list[Message], await owner.awaitable_attrs.messages)
    # Conversions are awaited one after another, never gathered concurrently.
    # NOTE(review): presumably because to_openai_model() may touch the same
    # session, which cannot serve concurrent operations. Confirm before
    # parallelizing.
    converted = [await record.to_openai_model() for record in stored]
    return AsyncCursorPage[OpenAIMessage](data=converted)