Coverage for src/fastoai/routers/beta/runs.py: 25%

80 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-12-06 09:34 +0800

1from asyncio import timeout 

2from datetime import datetime 

3from functools import wraps 

4from typing import cast 

5 

6from fastapi import APIRouter 

7from fastapi.responses import StreamingResponse 

8from openai.types.beta.assistant_stream_event import ( 

9 AssistantStreamEvent, 

10 ErrorEvent, 

11 ThreadMessageCreated, 

12 ThreadMessageDelta, 

13 ThreadMessageInProgress, 

14 ThreadRunCompleted, 

15 ThreadRunCreated, 

16 ThreadRunExpired, 

17 ThreadRunFailed, 

18 ThreadRunInProgress, 

19 ThreadRunQueued, 

20 ThreadRunStepCreated, 

21 ThreadRunStepInProgress, 

22) 

23from openai.types.beta.threads.run import LastError 

24from openai.types.beta.threads.run_create_params import ( 

25 RunCreateParams, 

26 RunCreateParamsStreaming, 

27) 

28from openai.types.beta.threads.runs.message_creation_step_details import ( 

29 MessageCreation, 

30 MessageCreationStepDetails, 

31) 

32from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam 

33from openai.types.shared import ErrorObject 

34from pydantic import RootModel 

35 

36from ...dependencies import OpenAIDependency, SessionDependency 

37from ...models import ( 

38 Assistant, 

39 Message, 

40 Run, 

41 RunStep, 

42 Thread, 

43) 

44 

45 

def _(event: AssistantStreamEvent):
    """Serialize *event* as a Server-Sent Events message.

    The SSE protocol dispatches a message only when a blank line is seen,
    so each serialized event must end with two newlines; with a single
    trailing newline consecutive events run together and clients would not
    dispatch them until the stream closes.
    """
    return f"event: {event.event}\ndata: {event.data.model_dump_json()}\n\n"

48 

49 

def run_decorator(run_model: Run, session: SessionDependency):
    """Wrap a run's event generator with run-lifecycle bookkeeping.

    The returned decorator wraps an async generator so that:

    * ``thread.run.created`` is emitted before the wrapped generator runs;
    * on normal completion the run is marked ``completed`` and
      ``thread.run.completed`` plus the terminating ``done`` sentinel are
      emitted;
    * if the run outlives ``run_model.expires_at`` the run is marked
      ``expired`` and ``thread.run.expired`` plus an ``error`` event are
      emitted;
    * any other exception marks the run ``failed`` with a ``server_error``
      last-error and emits ``thread.run.failed``.

    Every status transition is persisted through *session*.
    """

    def event_decorator(generator_func):
        @wraps(generator_func)
        async def wrapper(*args, **kwargs):
            yield _(
                ThreadRunCreated(
                    data=await run_model.to_openai_model(),
                    event="thread.run.created",
                )
            )
            try:
                # NOTE(review): expires_at is compared against a naive
                # datetime.now() — confirm both values use the same timezone
                # convention, otherwise the computed deadline is wrong.
                async with timeout(
                    None
                    if run_model.expires_at is None
                    else (run_model.expires_at - datetime.now()).total_seconds()
                ):
                    async for value in generator_func(*args, **kwargs):
                        yield value

                run_model.status = "completed"
                session.add(run_model)
                await session.commit()
                yield _(
                    ThreadRunCompleted(
                        data=await run_model.to_openai_model(),
                        event="thread.run.completed",
                    )
                )
                # SSE messages must be terminated by a blank line, hence the
                # double newline on this hand-built sentinel event.
                yield "event: done\ndata: [DONE]\n\n"
            except TimeoutError:
                run_model.status = "expired"
                session.add(run_model)
                await session.commit()
                yield _(
                    ThreadRunExpired(
                        data=await run_model.to_openai_model(),
                        event="thread.run.expired",
                    )
                )
                yield _(
                    ErrorEvent(
                        data=ErrorObject(message="Run expired", type="TimeoutError"),
                        event="error",
                    )
                )
            except Exception as e:
                # Boundary handler: surface any generator failure to the
                # client as a failed run instead of tearing the stream down.
                run_model.status = "failed"
                run_model.last_error = LastError(code="server_error", message=str(e))
                session.add(run_model)
                await session.commit()
                yield _(
                    ThreadRunFailed(
                        data=await run_model.to_openai_model(),
                        event="thread.run.failed",
                    )
                )

        return wrapper

    return event_decorator

110 

111 

112router = APIRouter() 

113 

114 

@router.post("/threads/{thread_id}/runs")
async def create_run(
    thread_id: str,
    params: RootModel[RunCreateParams],
    session: SessionDependency,
    client: OpenAIDependency,
):
    """Create a run on *thread_id* and stream its lifecycle as SSE events.

    Only the streaming variant is implemented; a non-streaming request
    raises ``NotImplementedError``.  The run's model, instructions and tools
    fall back to the assistant's stored values when not supplied in the
    request body.
    """
    if not params.root.get("stream", False):
        raise NotImplementedError("Non-streaming is not yet supported")
    run_params = cast(RunCreateParamsStreaming, params.model_dump())
    assistant = await session.get_one(Assistant, run_params["assistant_id"])
    thread = await session.get_one(Thread, thread_id)
    run = Run(  # type: ignore
        assistant=assistant,
        thread=thread,
        status="queued",
        model=run_params.get("model") or assistant.model,
        instructions=run_params.get("instructions") or assistant.instructions or "",
        parallel_tool_calls=run_params.get("parallel_tool_calls", True),
        tools=run_params.get("tools") or assistant.tools,
    )
    # Chat history sent to the completion API: system instructions first,
    # then the thread's stored messages in order.
    messages: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": run.instructions},
        *[m.model_dump() for m in thread.messages],  # type: ignore
    ]

    async def message_creation_step():
        """Yield SSE events for one assistant message produced by the model.

        Creates the message row and its ``message_creation`` run step, then
        relays completion deltas.  NOTE(review): the streamed text is never
        accumulated into ``message.content`` and neither the step nor the
        message is ever marked completed — confirm whether that is handled
        elsewhere.
        """
        message = Message(  # type: ignore
            thread=thread,
            assistant=assistant,
            run=run,
            content=[],
            role="assistant",
            status="in_progress",
        )
        session.add(message)
        step = RunStep(  # type: ignore
            run=run,
            thread=thread,
            assistant=assistant,
            status="in_progress",
            type="message_creation",
            step_details=MessageCreationStepDetails(
                message_creation=MessageCreation(message_id=message.id),
                type="message_creation",
            ),
        )
        yield _(
            ThreadRunStepCreated(
                data=await step.to_openai_model(),
                event="thread.run.step.created",
            )
        )
        session.add(step)
        await session.commit()
        await session.refresh(step)
        yield _(
            ThreadRunStepInProgress(
                data=await step.to_openai_model(),
                event="thread.run.step.in_progress",
            )
        )
        yield _(
            ThreadMessageCreated(
                data=await message.to_openai_model(),
                event="thread.message.created",
            )
        )
        await session.commit()
        await session.refresh(message)
        yield _(
            ThreadMessageInProgress(
                data=await message.to_openai_model(), event="thread.message.in_progress"
            )
        )
        async for part in await client.chat.completions.create(
            model=assistant.model,
            messages=messages,
            stream=True,
        ):
            # The final chunk (the one carrying finish_reason) has no
            # content; skip empty deltas so we never emit a text delta whose
            # value is None.
            content = part.choices[0].delta.content
            if not content:
                continue
            yield _(
                ThreadMessageDelta.model_validate(
                    dict(
                        event="thread.message.delta",
                        data=dict(
                            id=message.id,
                            delta=dict(
                                content=[
                                    dict(
                                        index=0,
                                        type="text",
                                        text=dict(
                                            value=content,
                                            annotations=[],
                                        ),
                                    )
                                ],
                                role="assistant",
                            ),
                            object="thread.message.delta",
                        ),
                    )
                )
            )

    @run_decorator(run, session)
    async def xrun():
        """Persist the queued run, then stream its in-progress events."""
        session.add(run)
        await session.commit()
        await session.refresh(run)
        yield _(
            ThreadRunQueued(
                event="thread.run.queued",
                data=await run.to_openai_model(),
            )
        )
        yield _(
            ThreadRunInProgress(
                event="thread.run.in_progress",
                data=await run.to_openai_model(),
            )
        )
        async for event in message_creation_step():
            yield event

    # Declare the SSE media type so clients treat the body as an event
    # stream rather than the default text/plain.
    return StreamingResponse(xrun(), media_type="text/event-stream")