Coverage for src/refinire/agents/flow/context.py: 95%

155 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-15 18:51 +0900

1from __future__ import annotations 

2 

3"""Context — Shared state management for Flow/Step workflows. 

4 

5Contextはフロー/ステップワークフロー用の共有状態管理を提供します。 

6型安全で読みやすく、LangChain LCELとの互換性も持ちます。 

7""" 

8 

9import asyncio 

10from typing import Any, Dict, List, Optional, Union 

11from datetime import datetime 

12 

13try: 

14 from pydantic import BaseModel, Field, PrivateAttr # type: ignore 

15except ImportError: 

16 BaseModel = object # type: ignore 

17 Field = lambda **kwargs: None # type: ignore 

18 PrivateAttr = lambda **kwargs: None # type: ignore 

19 

20 

21class Message(BaseModel): 

22 """ 

23 Message class for conversation history 

24 会話履歴用メッセージクラス 

25  

26 Attributes: 

27 role: Message role (user, assistant, system) / メッセージの役割 

28 content: Message content / メッセージ内容 

29 timestamp: Message timestamp / メッセージのタイムスタンプ 

30 metadata: Additional metadata / 追加メタデータ 

31 """ 

32 role: str # Message role (user, assistant, system) / メッセージの役割 

33 content: str # Message content / メッセージ内容 

34 timestamp: datetime = Field(default_factory=datetime.now) # Message timestamp / メッセージのタイムスタンプ 

35 metadata: Dict[str, Any] = Field(default_factory=dict) # Additional metadata / 追加メタデータ 

36 

37 

38class Context(BaseModel): 

39 """ 

40 Context class for Flow/Step workflow state management 

41 フロー/ステップワークフロー状態管理用コンテキストクラス 

42  

43 This class provides: 

44 このクラスは以下を提供します: 

45 - Type-safe shared state / 型安全な共有状態 

46 - Conversation history management / 会話履歴管理 

47 - Step routing control / ステップルーティング制御 

48 - LangChain LCEL compatibility / LangChain LCEL互換性 

49 - User input/output coordination / ユーザー入出力調整 

50 """ 

51 

52 # Core state / コア状態 

53 last_user_input: Optional[str] = None # Most recent user input / 直近のユーザー入力 

54 messages: List[Message] = Field(default_factory=list) # Conversation history / 会話履歴 

55 

56 # External data / 外部データ 

57 knowledge: Dict[str, Any] = Field(default_factory=dict) # External knowledge (RAG, etc.) / 外部知識(RAGなど) 

58 prev_outputs: Dict[str, Any] = Field(default_factory=dict) # Previous step outputs / 前ステップの出力 

59 

60 # Flow control / フロー制御 

61 next_label: Optional[str] = None # Next step routing instruction / 次ステップのルーティング指示 

62 current_step: Optional[str] = None # Current step name / 現在のステップ名 

63 

64 # Results / 結果 

65 artifacts: Dict[str, Any] = Field(default_factory=dict) # Flow-wide artifacts / フロー全体の成果物 

66 shared_state: Dict[str, Any] = Field(default_factory=dict) # Arbitrary shared values / 任意の共有値 

67 

68 # User interaction / ユーザー対話 

69 awaiting_prompt: Optional[str] = None # Prompt waiting for user input / ユーザー入力待ちのプロンプト 

70 awaiting_user_input: bool = False # Flag indicating waiting for user input / ユーザー入力待ちフラグ 

71 

72 # Execution metadata / 実行メタデータ 

73 trace_id: Optional[str] = None # Trace ID for observability / オブザーバビリティ用トレースID 

74 current_span_id: Optional[str] = None # Current span ID for step tracking / ステップ追跡用現在のスパンID 

75 start_time: datetime = Field(default_factory=datetime.now) # Flow start time / フロー開始時刻 

76 step_count: int = 0 # Number of steps executed / 実行されたステップ数 

77 span_history: List[Dict[str, Any]] = Field(default_factory=list) # Span execution history / スパン実行履歴 

78 

79 # Internal async coordination (private attributes) / 内部非同期調整(プライベート属性) 

80 _user_input_event: Optional[asyncio.Event] = PrivateAttr(default=None) 

81 _awaiting_prompt_event: Optional[asyncio.Event] = PrivateAttr(default=None) 

82 

83 def __init__(self, **data): 

84 """ 

85 Initialize Context with async events 

86 非同期イベントでContextを初期化 

87 """ 

88 super().__init__(**data) 

89 self._user_input_event = asyncio.Event() 

90 self._awaiting_prompt_event = asyncio.Event() 

91 

92 def add_user_message(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None: 

93 """ 

94 Add user message to conversation history 

95 ユーザーメッセージを会話履歴に追加 

96  

97 Args: 

98 content: Message content / メッセージ内容 

99 metadata: Additional metadata / 追加メタデータ 

100 """ 

101 message = Message( 

102 role="user", 

103 content=content, 

104 metadata=metadata or {} 

105 ) 

106 self.messages.append(message) 

107 self.last_user_input = content 

108 

109 def add_assistant_message(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None: 

110 """ 

111 Add assistant message to conversation history 

112 アシスタントメッセージを会話履歴に追加 

113  

114 Args: 

115 content: Message content / メッセージ内容 

116 metadata: Additional metadata / 追加メタデータ 

117 """ 

118 message = Message( 

119 role="assistant", 

120 content=content, 

121 metadata=metadata or {} 

122 ) 

123 self.messages.append(message) 

124 

125 def add_system_message(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None: 

126 """ 

127 Add system message to conversation history 

128 システムメッセージを会話履歴に追加 

129  

130 Args: 

131 content: Message content / メッセージ内容 

132 metadata: Additional metadata / 追加メタデータ 

133 """ 

134 message = Message( 

135 role="system", 

136 content=content, 

137 metadata=metadata or {} 

138 ) 

139 self.messages.append(message) 

140 

141 def set_waiting_for_user_input(self, prompt: str) -> None: 

142 """ 

143 Set context to wait for user input with a prompt 

144 プロンプトでユーザー入力待ち状態に設定 

145  

146 Args: 

147 prompt: Prompt to display to user / ユーザーに表示するプロンプト 

148 """ 

149 self.awaiting_prompt = prompt 

150 self.awaiting_user_input = True 

151 if self._awaiting_prompt_event: 

152 self._awaiting_prompt_event.set() 

153 

154 def provide_user_input(self, user_input: str) -> None: 

155 """ 

156 Provide user input and clear waiting state 

157 ユーザー入力を提供し、待ち状態をクリア 

158  

159 Args: 

160 user_input: User input text / ユーザー入力テキスト 

161 """ 

162 self.add_user_message(user_input) 

163 self.awaiting_prompt = None 

164 self.awaiting_user_input = False 

165 if self._user_input_event: 

166 self._user_input_event.set() 

167 

168 def clear_prompt(self) -> Optional[str]: 

169 """ 

170 Clear and return the current prompt 

171 現在のプロンプトをクリアして返す 

172  

173 Returns: 

174 str | None: The prompt if one was waiting / 待機中だったプロンプト 

175 """ 

176 prompt = self.awaiting_prompt 

177 self.awaiting_prompt = None 

178 if self._awaiting_prompt_event: 

179 self._awaiting_prompt_event.clear() 

180 return prompt 

181 

182 async def wait_for_user_input(self) -> str: 

183 """ 

184 Async wait for user input 

185 ユーザー入力を非同期で待機 

186  

187 Returns: 

188 str: User input / ユーザー入力 

189 """ 

190 if self._user_input_event: 

191 await self._user_input_event.wait() 

192 self._user_input_event.clear() 

193 return self.last_user_input or "" 

194 

195 async def wait_for_prompt_event(self) -> str: 

196 """ 

197 Async wait for prompt event 

198 プロンプトイベントを非同期で待機 

199  

200 Returns: 

201 str: Prompt waiting for user / ユーザー待ちのプロンプト 

202 """ 

203 if self._awaiting_prompt_event: 

204 await self._awaiting_prompt_event.wait() 

205 return self.awaiting_prompt or "" 

206 

207 def goto(self, label: str) -> None: 

208 """ 

209 Set next step routing 

210 次ステップのルーティングを設定 

211  

212 Args: 

213 label: Next step label / 次ステップのラベル 

214 """ 

215 self.next_label = label 

216 

217 def finish(self) -> None: 

218 """ 

219 Mark flow as finished 

220 フローを完了としてマーク 

221 """ 

222 self.next_label = None 

223 

224 def is_finished(self) -> bool: 

225 """ 

226 Check if flow is finished 

227 フローが完了しているかチェック 

228  

229 Returns: 

230 bool: True if finished / 完了している場合True 

231 """ 

232 return self.next_label is None 

233 

234 @property 

235 def finished(self) -> bool: 

236 """ 

237 Property to check if flow is finished 

238 フローが完了しているかチェックするプロパティ 

239  

240 Returns: 

241 bool: True if finished / 完了している場合True 

242 """ 

243 return self.is_finished() 

244 

245 def as_dict(self) -> Dict[str, Any]: 

246 """ 

247 Convert to dictionary for LangChain LCEL compatibility 

248 LangChain LCEL互換性のために辞書に変換 

249  

250 Returns: 

251 Dict[str, Any]: Dictionary representation / 辞書表現 

252 """ 

253 data = self.model_dump() 

254 # Convert messages to LangChain format 

255 # メッセージをLangChain形式に変換 

256 data["history"] = [ 

257 {"role": msg.role, "content": msg.content, "metadata": msg.metadata} 

258 for msg in self.messages 

259 ] 

260 data.pop("messages", None) 

261 return data 

262 

263 @classmethod 

264 def from_dict(cls, data: Dict[str, Any]) -> "Context": 

265 """ 

266 Create Context from dictionary (LangChain LCEL compatibility) 

267 辞書からContextを作成(LangChain LCEL互換性) 

268  

269 Args: 

270 data: Dictionary data / 辞書データ 

271  

272 Returns: 

273 Context: New context instance / 新しいコンテキストインスタンス 

274 """ 

275 data = data.copy() 

276 # Convert history to messages 

277 # 履歴をメッセージに変換 

278 history = data.pop("history", []) 

279 messages = [] 

280 for msg_data in history: 

281 if isinstance(msg_data, dict): 

282 messages.append(Message( 

283 role=msg_data.get("role", "user"), 

284 content=msg_data.get("content", ""), 

285 metadata=msg_data.get("metadata", {}) 

286 )) 

287 data["messages"] = messages 

288 return cls(**data) 

289 

290 def get_conversation_text(self, include_system: bool = False) -> str: 

291 """ 

292 Get conversation as formatted text 

293 会話をフォーマット済みテキストとして取得 

294  

295 Args: 

296 include_system: Include system messages / システムメッセージを含める 

297  

298 Returns: 

299 str: Formatted conversation / フォーマット済み会話 

300 """ 

301 lines = [] 

302 for msg in self.messages: 

303 if not include_system and msg.role == "system": 

304 continue 

305 role_label = {"user": "👤", "assistant": "🤖", "system": "⚙️"}.get(msg.role, msg.role) 

306 lines.append(f"{role_label} {msg.content}") 

307 return "\n".join(lines) 

308 

309 def get_last_messages(self, n: int = 10) -> List[Message]: 

310 """ 

311 Get last N messages 

312 最後のNメッセージを取得 

313  

314 Args: 

315 n: Number of messages / メッセージ数 

316  

317 Returns: 

318 List[Message]: Last N messages / 最後のNメッセージ 

319 """ 

320 return self.messages[-n:] if len(self.messages) > n else self.messages.copy() 

321 

322 def update_step_info(self, step_name: str) -> None: 

323 """ 

324 Update current step information 

325 現在のステップ情報を更新 

326  

327 Args: 

328 step_name: Current step name / 現在のステップ名 

329 """ 

330 # Finalize previous span if exists 

331 # 前のスパンが存在する場合は終了 

332 if self.current_span_id and self.current_step: 

333 self._finalize_current_span() 

334 

335 # Start new span 

336 # 新しいスパンを開始 

337 self.current_step = step_name 

338 self.step_count += 1 

339 self.current_span_id = self._generate_span_id(step_name) 

340 

341 # Record span start 

342 # スパン開始を記録 

343 self._start_span(step_name) 

344 

345 def _generate_span_id(self, step_name: str) -> str: 

346 """ 

347 Generate a unique span ID for the step 

348 ステップ用のユニークなスパンIDを生成 

349  

350 Args: 

351 step_name: Step name / ステップ名 

352  

353 Returns: 

354 str: Generated span ID / 生成されたスパンID 

355 """ 

356 timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') 

357 return f"{step_name}_{self.step_count:03d}_{timestamp}" 

358 

359 def _start_span(self, step_name: str) -> None: 

360 """ 

361 Start a new span for step tracking 

362 ステップ追跡用の新しいスパンを開始 

363  

364 Args: 

365 step_name: Step name / ステップ名 

366 """ 

367 span_info = { 

368 "span_id": self.current_span_id, 

369 "step_name": step_name, 

370 "trace_id": self.trace_id, 

371 "start_time": datetime.now(), 

372 "end_time": None, 

373 "status": "started", 

374 "step_index": self.step_count, 

375 "metadata": {} 

376 } 

377 self.span_history.append(span_info) 

378 

379 def _finalize_current_span(self, status: str = "completed", error: Optional[str] = None) -> None: 

380 """ 

381 Finalize the current span 

382 現在のスパンを終了 

383  

384 Args: 

385 status: Span status (completed, error, etc.) / スパンステータス 

386 error: Error message if failed / 失敗時のエラーメッセージ 

387 """ 

388 if not self.span_history: 

389 return 

390 

391 current_span = self.span_history[-1] 

392 if current_span["span_id"] == self.current_span_id: 

393 current_span["end_time"] = datetime.now() 

394 current_span["status"] = status 

395 if error: 

396 current_span["error"] = error 

397 

398 def finalize_flow_span(self) -> None: 

399 """ 

400 Finalize the current span when flow ends 

401 フロー終了時に現在のスパンを終了 

402 """ 

403 if self.current_span_id: 

404 self._finalize_current_span() 

405 self.current_span_id = None 

406 

407 def set_artifact(self, key: str, value: Any) -> None: 

408 """ 

409 Set artifact value 

410 成果物の値を設定 

411  

412 Args: 

413 key: Artifact key / 成果物キー 

414 value: Artifact value / 成果物値 

415 """ 

416 self.artifacts[key] = value 

417 

418 def get_artifact(self, key: str, default: Any = None) -> Any: 

419 """ 

420 Get artifact value 

421 成果物の値を取得 

422  

423 Args: 

424 key: Artifact key / 成果物キー 

425 default: Default value if not found / 見つからない場合のデフォルト値 

426  

427 Returns: 

428 Any: Artifact value / 成果物値 

429 """ 

430 return self.artifacts.get(key, default) 

431 

432 def get_current_span_info(self) -> Optional[Dict[str, Any]]: 

433 """ 

434 Get current span information 

435 現在のスパン情報を取得 

436  

437 Returns: 

438 Dict[str, Any] | None: Current span info / 現在のスパン情報 

439 """ 

440 if self.current_span_id and self.span_history: 

441 return self.span_history[-1] 

442 return None 

443 

444 def get_span_history(self) -> List[Dict[str, Any]]: 

445 """ 

446 Get complete span execution history 

447 完全なスパン実行履歴を取得 

448  

449 Returns: 

450 List[Dict[str, Any]]: Span history / スパン履歴 

451 """ 

452 return self.span_history.copy() 

453 

454 def get_trace_summary(self) -> Dict[str, Any]: 

455 """ 

456 Get comprehensive trace summary 

457 包括的なトレースサマリーを取得 

458  

459 Returns: 

460 Dict[str, Any]: Trace summary / トレースサマリー 

461 """ 

462 total_duration = None 

463 if self.span_history: 

464 start_time = min(span["start_time"] for span in self.span_history) 

465 completed_spans = [span for span in self.span_history if span.get("end_time")] 

466 if completed_spans: 

467 end_time = max(span["end_time"] for span in completed_spans) 

468 total_duration = (end_time - start_time).total_seconds() 

469 

470 return { 

471 "trace_id": self.trace_id, 

472 "current_span_id": self.current_span_id, 

473 "total_spans": len(self.span_history), 

474 "completed_spans": len([s for s in self.span_history if s.get("status") == "completed"]), 

475 "active_spans": len([s for s in self.span_history if s.get("status") == "started"]), 

476 "error_spans": len([s for s in self.span_history if s.get("status") == "error"]), 

477 "total_duration_seconds": total_duration, 

478 "flow_start_time": self.start_time, 

479 "is_finished": self.is_finished() 

480 } 

481 

482 class Config: 

483 # Allow arbitrary types for flexibility 

484 # 柔軟性のために任意の型を許可 

485 arbitrary_types_allowed = True