Coverage for src\agents_sdk_models\context.py: 96%

112 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-04 17:38 +0900

1from __future__ import annotations 

2 

3"""Context — Shared state management for Flow/Step workflows. 

4 

5Contextはフロー/ステップワークフロー用の共有状態管理を提供します。 

6型安全で読みやすく、LangChain LCELとの互換性も持ちます。 

7""" 

8 

9import asyncio 

10from typing import Any, Dict, List, Optional, Union 

11from datetime import datetime 

12 

13try: 

14 from pydantic import BaseModel, Field, PrivateAttr # type: ignore 

15except ImportError: 

16 BaseModel = object # type: ignore 

17 Field = lambda **kwargs: None # type: ignore 

18 PrivateAttr = lambda **kwargs: None # type: ignore 

19 

20 

21class Message(BaseModel): 

22 """ 

23 Message class for conversation history 

24 会話履歴用メッセージクラス 

25  

26 Attributes: 

27 role: Message role (user, assistant, system) / メッセージの役割 

28 content: Message content / メッセージ内容 

29 timestamp: Message timestamp / メッセージのタイムスタンプ 

30 metadata: Additional metadata / 追加メタデータ 

31 """ 

32 role: str # Message role (user, assistant, system) / メッセージの役割 

33 content: str # Message content / メッセージ内容 

34 timestamp: datetime = Field(default_factory=datetime.now) # Message timestamp / メッセージのタイムスタンプ 

35 metadata: Dict[str, Any] = Field(default_factory=dict) # Additional metadata / 追加メタデータ 

36 

37 

38class Context(BaseModel): 

39 """ 

40 Context class for Flow/Step workflow state management 

41 フロー/ステップワークフロー状態管理用コンテキストクラス 

42  

43 This class provides: 

44 このクラスは以下を提供します: 

45 - Type-safe shared state / 型安全な共有状態 

46 - Conversation history management / 会話履歴管理 

47 - Step routing control / ステップルーティング制御 

48 - LangChain LCEL compatibility / LangChain LCEL互換性 

49 - User input/output coordination / ユーザー入出力調整 

50 """ 

51 

52 # Core state / コア状態 

53 last_user_input: Optional[str] = None # Most recent user input / 直近のユーザー入力 

54 messages: List[Message] = Field(default_factory=list) # Conversation history / 会話履歴 

55 

56 # External data / 外部データ 

57 knowledge: Dict[str, Any] = Field(default_factory=dict) # External knowledge (RAG, etc.) / 外部知識(RAGなど) 

58 prev_outputs: Dict[str, Any] = Field(default_factory=dict) # Previous step outputs / 前ステップの出力 

59 

60 # Flow control / フロー制御 

61 next_label: Optional[str] = None # Next step routing instruction / 次ステップのルーティング指示 

62 current_step: Optional[str] = None # Current step name / 現在のステップ名 

63 

64 # Results / 結果 

65 artifacts: Dict[str, Any] = Field(default_factory=dict) # Flow-wide artifacts / フロー全体の成果物 

66 shared_state: Dict[str, Any] = Field(default_factory=dict) # Arbitrary shared values / 任意の共有値 

67 

68 # User interaction / ユーザー対話 

69 awaiting_prompt: Optional[str] = None # Prompt waiting for user input / ユーザー入力待ちのプロンプト 

70 awaiting_user_input: bool = False # Flag indicating waiting for user input / ユーザー入力待ちフラグ 

71 

72 # Execution metadata / 実行メタデータ 

73 trace_id: Optional[str] = None # Trace ID for observability / オブザーバビリティ用トレースID 

74 start_time: datetime = Field(default_factory=datetime.now) # Flow start time / フロー開始時刻 

75 step_count: int = 0 # Number of steps executed / 実行されたステップ数 

76 

77 # Internal async coordination (private attributes) / 内部非同期調整(プライベート属性) 

78 _user_input_event: Optional[asyncio.Event] = PrivateAttr(default=None) 

79 _awaiting_prompt_event: Optional[asyncio.Event] = PrivateAttr(default=None) 

80 

81 def __init__(self, **data): 

82 """ 

83 Initialize Context with async events 

84 非同期イベントでContextを初期化 

85 """ 

86 super().__init__(**data) 

87 self._user_input_event = asyncio.Event() 

88 self._awaiting_prompt_event = asyncio.Event() 

89 

90 def add_user_message(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None: 

91 """ 

92 Add user message to conversation history 

93 ユーザーメッセージを会話履歴に追加 

94  

95 Args: 

96 content: Message content / メッセージ内容 

97 metadata: Additional metadata / 追加メタデータ 

98 """ 

99 message = Message( 

100 role="user", 

101 content=content, 

102 metadata=metadata or {} 

103 ) 

104 self.messages.append(message) 

105 self.last_user_input = content 

106 

107 def add_assistant_message(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None: 

108 """ 

109 Add assistant message to conversation history 

110 アシスタントメッセージを会話履歴に追加 

111  

112 Args: 

113 content: Message content / メッセージ内容 

114 metadata: Additional metadata / 追加メタデータ 

115 """ 

116 message = Message( 

117 role="assistant", 

118 content=content, 

119 metadata=metadata or {} 

120 ) 

121 self.messages.append(message) 

122 

123 def add_system_message(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None: 

124 """ 

125 Add system message to conversation history 

126 システムメッセージを会話履歴に追加 

127  

128 Args: 

129 content: Message content / メッセージ内容 

130 metadata: Additional metadata / 追加メタデータ 

131 """ 

132 message = Message( 

133 role="system", 

134 content=content, 

135 metadata=metadata or {} 

136 ) 

137 self.messages.append(message) 

138 

139 def set_waiting_for_user_input(self, prompt: str) -> None: 

140 """ 

141 Set context to wait for user input with a prompt 

142 プロンプトでユーザー入力待ち状態に設定 

143  

144 Args: 

145 prompt: Prompt to display to user / ユーザーに表示するプロンプト 

146 """ 

147 self.awaiting_prompt = prompt 

148 self.awaiting_user_input = True 

149 if self._awaiting_prompt_event: 

150 self._awaiting_prompt_event.set() 

151 

152 def provide_user_input(self, user_input: str) -> None: 

153 """ 

154 Provide user input and clear waiting state 

155 ユーザー入力を提供し、待ち状態をクリア 

156  

157 Args: 

158 user_input: User input text / ユーザー入力テキスト 

159 """ 

160 self.add_user_message(user_input) 

161 self.awaiting_prompt = None 

162 self.awaiting_user_input = False 

163 if self._user_input_event: 

164 self._user_input_event.set() 

165 

166 def clear_prompt(self) -> Optional[str]: 

167 """ 

168 Clear and return the current prompt 

169 現在のプロンプトをクリアして返す 

170  

171 Returns: 

172 str | None: The prompt if one was waiting / 待機中だったプロンプト 

173 """ 

174 prompt = self.awaiting_prompt 

175 self.awaiting_prompt = None 

176 if self._awaiting_prompt_event: 

177 self._awaiting_prompt_event.clear() 

178 return prompt 

179 

180 async def wait_for_user_input(self) -> str: 

181 """ 

182 Async wait for user input 

183 ユーザー入力を非同期で待機 

184  

185 Returns: 

186 str: User input / ユーザー入力 

187 """ 

188 if self._user_input_event: 

189 await self._user_input_event.wait() 

190 self._user_input_event.clear() 

191 return self.last_user_input or "" 

192 

193 async def wait_for_prompt_event(self) -> str: 

194 """ 

195 Async wait for prompt event 

196 プロンプトイベントを非同期で待機 

197  

198 Returns: 

199 str: Prompt waiting for user / ユーザー待ちのプロンプト 

200 """ 

201 if self._awaiting_prompt_event: 

202 await self._awaiting_prompt_event.wait() 

203 return self.awaiting_prompt or "" 

204 

205 def goto(self, label: str) -> None: 

206 """ 

207 Set next step routing 

208 次ステップのルーティングを設定 

209  

210 Args: 

211 label: Next step label / 次ステップのラベル 

212 """ 

213 self.next_label = label 

214 

215 def finish(self) -> None: 

216 """ 

217 Mark flow as finished 

218 フローを完了としてマーク 

219 """ 

220 self.next_label = None 

221 

222 def is_finished(self) -> bool: 

223 """ 

224 Check if flow is finished 

225 フローが完了しているかチェック 

226  

227 Returns: 

228 bool: True if finished / 完了している場合True 

229 """ 

230 return self.next_label is None 

231 

232 def as_dict(self) -> Dict[str, Any]: 

233 """ 

234 Convert to dictionary for LangChain LCEL compatibility 

235 LangChain LCEL互換性のために辞書に変換 

236  

237 Returns: 

238 Dict[str, Any]: Dictionary representation / 辞書表現 

239 """ 

240 data = self.dict() 

241 # Convert messages to LangChain format 

242 # メッセージをLangChain形式に変換 

243 data["history"] = [ 

244 {"role": msg.role, "content": msg.content, "metadata": msg.metadata} 

245 for msg in self.messages 

246 ] 

247 data.pop("messages", None) 

248 return data 

249 

250 @classmethod 

251 def from_dict(cls, data: Dict[str, Any]) -> "Context": 

252 """ 

253 Create Context from dictionary (LangChain LCEL compatibility) 

254 辞書からContextを作成(LangChain LCEL互換性) 

255  

256 Args: 

257 data: Dictionary data / 辞書データ 

258  

259 Returns: 

260 Context: New context instance / 新しいコンテキストインスタンス 

261 """ 

262 data = data.copy() 

263 # Convert history to messages 

264 # 履歴をメッセージに変換 

265 history = data.pop("history", []) 

266 messages = [] 

267 for msg_data in history: 

268 if isinstance(msg_data, dict): 

269 messages.append(Message( 

270 role=msg_data.get("role", "user"), 

271 content=msg_data.get("content", ""), 

272 metadata=msg_data.get("metadata", {}) 

273 )) 

274 data["messages"] = messages 

275 return cls(**data) 

276 

277 def get_conversation_text(self, include_system: bool = False) -> str: 

278 """ 

279 Get conversation as formatted text 

280 会話をフォーマット済みテキストとして取得 

281  

282 Args: 

283 include_system: Include system messages / システムメッセージを含める 

284  

285 Returns: 

286 str: Formatted conversation / フォーマット済み会話 

287 """ 

288 lines = [] 

289 for msg in self.messages: 

290 if not include_system and msg.role == "system": 

291 continue 

292 role_label = {"user": "👤", "assistant": "🤖", "system": "⚙️"}.get(msg.role, msg.role) 

293 lines.append(f"{role_label} {msg.content}") 

294 return "\n".join(lines) 

295 

296 def get_last_messages(self, n: int = 10) -> List[Message]: 

297 """ 

298 Get last N messages 

299 最後のNメッセージを取得 

300  

301 Args: 

302 n: Number of messages / メッセージ数 

303  

304 Returns: 

305 List[Message]: Last N messages / 最後のNメッセージ 

306 """ 

307 return self.messages[-n:] if len(self.messages) > n else self.messages.copy() 

308 

309 def update_step_info(self, step_name: str) -> None: 

310 """ 

311 Update current step information 

312 現在のステップ情報を更新 

313  

314 Args: 

315 step_name: Current step name / 現在のステップ名 

316 """ 

317 self.current_step = step_name 

318 self.step_count += 1 

319 

320 def set_artifact(self, key: str, value: Any) -> None: 

321 """ 

322 Set artifact value 

323 成果物の値を設定 

324  

325 Args: 

326 key: Artifact key / 成果物キー 

327 value: Artifact value / 成果物値 

328 """ 

329 self.artifacts[key] = value 

330 

331 def get_artifact(self, key: str, default: Any = None) -> Any: 

332 """ 

333 Get artifact value 

334 成果物の値を取得 

335  

336 Args: 

337 key: Artifact key / 成果物キー 

338 default: Default value if not found / 見つからない場合のデフォルト値 

339  

340 Returns: 

341 Any: Artifact value / 成果物値 

342 """ 

343 return self.artifacts.get(key, default) 

344 

345 class Config: 

346 # Allow arbitrary types for flexibility 

347 # 柔軟性のために任意の型を許可 

348 arbitrary_types_allowed = True