Coverage for src/fastoai/routers/beta/runs.py: 25% (80 statements)
coverage.py v7.6.8, created at 2024-12-06 09:34 +0800
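# Streaming implementation of the OpenAI Assistants "create run" flow.
# Run, run-step, and message lifecycle updates are emitted as
# server-sent events (SSE) using the openai-python event types.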
from asyncio import timeout
from datetime import datetime
from functools import wraps
from typing import cast

from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from openai.types.beta.assistant_stream_event import (
    AssistantStreamEvent,
    ErrorEvent,
    ThreadMessageCreated,
    ThreadMessageDelta,
    ThreadMessageInProgress,
    ThreadRunCompleted,
    ThreadRunCreated,
    ThreadRunExpired,
    ThreadRunFailed,
    ThreadRunInProgress,
    ThreadRunQueued,
    ThreadRunStepCreated,
    ThreadRunStepInProgress,
)
from openai.types.beta.threads.run import LastError
from openai.types.beta.threads.run_create_params import (
    RunCreateParams,
    RunCreateParamsStreaming,
)
from openai.types.beta.threads.runs.message_creation_step_details import (
    MessageCreation,
    MessageCreationStepDetails,
)
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from openai.types.shared import ErrorObject
from pydantic import RootModel

from ...dependencies import OpenAIDependency, SessionDependency
from ...models import (
    Assistant,
    Message,
    Run,
    RunStep,
    Thread,
)
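
# Format one AssistantStreamEvent as a server-sent-events frame.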
def _(event: AssistantStreamEvent) -> str:
    # The trailing blank line terminates the SSE frame so clients dispatch
    # the event immediately (a single "\n" would run frames together).
    return f"event: {event.event}\ndata: {event.data.model_dump_json()}\n\n"
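
# Wrap a run's event generator with lifecycle framing: emit
# thread.run.created up front, enforce the run's expiry as a timeout,
# and close with thread.run.completed, .expired, or .failed while
# persisting each status change to the session.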
def run_decorator(run_model: Run, session: SessionDependency):
    def event_decorator(generator_func):
        @wraps(generator_func)
        async def wrapper(*args, **kwargs):
            yield _(
                ThreadRunCreated(
                    data=await run_model.to_openai_model(),
                    event="thread.run.created",
                )
            )
            try:
                # Enforce the run's expiry (if any) as a hard timeout on the
                # wrapped generator.
                async with timeout(
                    None
                    if run_model.expires_at is None
                    else (run_model.expires_at - datetime.now()).total_seconds()
                ):
                    async for value in generator_func(*args, **kwargs):
                        yield value

                run_model.status = "completed"
                session.add(run_model)
                await session.commit()
                yield _(
                    ThreadRunCompleted(
                        data=await run_model.to_openai_model(),
                        event="thread.run.completed",
                    )
                )
                # Terminal SSE sentinel; the blank line closes the frame.
                yield "event: done\ndata: [DONE]\n\n"
            except TimeoutError:
                run_model.status = "expired"
                session.add(run_model)
                await session.commit()
                yield _(
                    ThreadRunExpired(
                        data=await run_model.to_openai_model(),
                        event="thread.run.expired",
                    )
                )
                yield _(
                    ErrorEvent(
                        data=ErrorObject(message="Run expired", type="TimeoutError"),
                        event="error",
                    )
                )
            except Exception as e:
                run_model.status = "failed"
                run_model.last_error = LastError(code="server_error", message=str(e))
                session.add(run_model)
                await session.commit()
                yield _(
                    ThreadRunFailed(
                        data=await run_model.to_openai_model(),
                        event="thread.run.failed",
                    )
                )

        return wrapper

    return event_decorator
router = APIRouter()
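
# POST /threads/{thread_id}/runs: streaming-only implementation of the
# Assistants "create run" endpoint.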
@router.post("/threads/{thread_id}/runs")
async def create_run(
    thread_id: str,
    params: RootModel[RunCreateParams],
    session: SessionDependency,
    client: OpenAIDependency,
):
    if not params.root.get("stream", False):
        raise NotImplementedError("Non-streaming is not yet supported")
    run_params = cast(RunCreateParamsStreaming, params.model_dump())
    assistant = await session.get_one(Assistant, run_params["assistant_id"])
    thread = await session.get_one(Thread, thread_id)
    # Request params override the assistant's defaults where provided.
    run = Run(  # type: ignore
        assistant=assistant,
        thread=thread,
        status="queued",
        model=run_params.get("model") or assistant.model,
        instructions=run_params.get("instructions") or assistant.instructions or "",
        parallel_tool_calls=run_params.get("parallel_tool_calls", True),
        tools=run_params.get("tools") or assistant.tools,
    )
    # Chat history for the completion call: system instructions first,
    # then the thread's messages.
    messages: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": run.instructions},
        *[m.model_dump() for m in thread.messages],  # type: ignore
    ]
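
    # One message-creation step: persist the Message and RunStep rows,
    # emit their created/in_progress events, then relay model output as
    # thread.message.delta events.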
    async def message_creation_step():
        message = Message(  # type: ignore
            thread=thread,
            assistant=assistant,
            run=run,
            content=[],
            role="assistant",
            status="in_progress",
        )
        session.add(message)
        step = RunStep(  # type: ignore
            run=run,
            thread=thread,
            assistant=assistant,
            status="in_progress",
            type="message_creation",
            step_details=MessageCreationStepDetails(
                message_creation=MessageCreation(message_id=message.id),
                type="message_creation",
            ),
        )
        yield _(
            ThreadRunStepCreated(
                data=await step.to_openai_model(),
                event="thread.run.step.created",
            )
        )
        session.add(step)
        await session.commit()
        await session.refresh(step)
        yield _(
            ThreadRunStepInProgress(
                data=await step.to_openai_model(),
                event="thread.run.step.in_progress",
            )
        )
        yield _(
            ThreadMessageCreated(
                data=await message.to_openai_model(),
                event="thread.message.created",
            )
        )
        await session.commit()
        await session.refresh(message)
        yield _(
            ThreadMessageInProgress(
                data=await message.to_openai_model(),
                event="thread.message.in_progress",
            )
        )
        async for part in await client.chat.completions.create(
            model=assistant.model,
            messages=messages,
            stream=True,
        ):
            # Skip chunks without text content (e.g. the role-only first
            # chunk and the final chunk carrying only finish_reason).
            if not part.choices or part.choices[0].delta.content is None:
                continue
            yield _(
                ThreadMessageDelta.model_validate(
                    dict(
                        event="thread.message.delta",
                        data=dict(
                            id=message.id,
                            delta=dict(
                                content=[
                                    dict(
                                        index=0,
                                        type="text",
                                        text=dict(
                                            value=part.choices[0].delta.content,
                                            annotations=[],
                                        ),
                                    )
                                ],
                                role="assistant",
                            ),
                            object="thread.message.delta",
                        ),
                    )
                )
            )
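
    # The run body wrapped by run_decorator: persist the queued run, emit
    # queued and in_progress events, then stream the message-creation step.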
    @run_decorator(run, session)
    async def xrun():
        session.add(run)
        await session.commit()
        await session.refresh(run)
        yield _(
            ThreadRunQueued(
                event="thread.run.queued",
                data=await run.to_openai_model(),
            )
        )

        yield _(
            ThreadRunInProgress(
                event="thread.run.in_progress",
                data=await run.to_openai_model(),
            )
        )
        async for chunk in message_creation_step():
            yield chunk

    # text/event-stream lets SSE-aware clients parse the frames.
    return StreamingResponse(xrun(), media_type="text/event-stream")
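
# Example request (hypothetical host and IDs), reading the SSE stream:
#   curl -N http://localhost:8000/threads/thread_123/runs \
#     -H "Content-Type: application/json" \
#     -d '{"assistant_id": "asst_123", "stream": true}'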