Coverage for src/refinire/agents/flow/context.py: 95%
155 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-15 18:51 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-15 18:51 +0900
1from __future__ import annotations
3"""Context — Shared state management for Flow/Step workflows.
5Contextはフロー/ステップワークフロー用の共有状態管理を提供します。
6型安全で読みやすく、LangChain LCELとの互換性も持ちます。
7"""
9import asyncio
10from typing import Any, Dict, List, Optional, Union
11from datetime import datetime
13try:
14 from pydantic import BaseModel, Field, PrivateAttr # type: ignore
15except ImportError:
16 BaseModel = object # type: ignore
17 Field = lambda **kwargs: None # type: ignore
18 PrivateAttr = lambda **kwargs: None # type: ignore
21class Message(BaseModel):
22 """
23 Message class for conversation history
24 会話履歴用メッセージクラス
26 Attributes:
27 role: Message role (user, assistant, system) / メッセージの役割
28 content: Message content / メッセージ内容
29 timestamp: Message timestamp / メッセージのタイムスタンプ
30 metadata: Additional metadata / 追加メタデータ
31 """
32 role: str # Message role (user, assistant, system) / メッセージの役割
33 content: str # Message content / メッセージ内容
34 timestamp: datetime = Field(default_factory=datetime.now) # Message timestamp / メッセージのタイムスタンプ
35 metadata: Dict[str, Any] = Field(default_factory=dict) # Additional metadata / 追加メタデータ
38class Context(BaseModel):
39 """
40 Context class for Flow/Step workflow state management
41 フロー/ステップワークフロー状態管理用コンテキストクラス
43 This class provides:
44 このクラスは以下を提供します:
45 - Type-safe shared state / 型安全な共有状態
46 - Conversation history management / 会話履歴管理
47 - Step routing control / ステップルーティング制御
48 - LangChain LCEL compatibility / LangChain LCEL互換性
49 - User input/output coordination / ユーザー入出力調整
50 """
52 # Core state / コア状態
53 last_user_input: Optional[str] = None # Most recent user input / 直近のユーザー入力
54 messages: List[Message] = Field(default_factory=list) # Conversation history / 会話履歴
56 # External data / 外部データ
57 knowledge: Dict[str, Any] = Field(default_factory=dict) # External knowledge (RAG, etc.) / 外部知識(RAGなど)
58 prev_outputs: Dict[str, Any] = Field(default_factory=dict) # Previous step outputs / 前ステップの出力
60 # Flow control / フロー制御
61 next_label: Optional[str] = None # Next step routing instruction / 次ステップのルーティング指示
62 current_step: Optional[str] = None # Current step name / 現在のステップ名
64 # Results / 結果
65 artifacts: Dict[str, Any] = Field(default_factory=dict) # Flow-wide artifacts / フロー全体の成果物
66 shared_state: Dict[str, Any] = Field(default_factory=dict) # Arbitrary shared values / 任意の共有値
68 # User interaction / ユーザー対話
69 awaiting_prompt: Optional[str] = None # Prompt waiting for user input / ユーザー入力待ちのプロンプト
70 awaiting_user_input: bool = False # Flag indicating waiting for user input / ユーザー入力待ちフラグ
72 # Execution metadata / 実行メタデータ
73 trace_id: Optional[str] = None # Trace ID for observability / オブザーバビリティ用トレースID
74 current_span_id: Optional[str] = None # Current span ID for step tracking / ステップ追跡用現在のスパンID
75 start_time: datetime = Field(default_factory=datetime.now) # Flow start time / フロー開始時刻
76 step_count: int = 0 # Number of steps executed / 実行されたステップ数
77 span_history: List[Dict[str, Any]] = Field(default_factory=list) # Span execution history / スパン実行履歴
79 # Internal async coordination (private attributes) / 内部非同期調整(プライベート属性)
80 _user_input_event: Optional[asyncio.Event] = PrivateAttr(default=None)
81 _awaiting_prompt_event: Optional[asyncio.Event] = PrivateAttr(default=None)
83 def __init__(self, **data):
84 """
85 Initialize Context with async events
86 非同期イベントでContextを初期化
87 """
88 super().__init__(**data)
89 self._user_input_event = asyncio.Event()
90 self._awaiting_prompt_event = asyncio.Event()
92 def add_user_message(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
93 """
94 Add user message to conversation history
95 ユーザーメッセージを会話履歴に追加
97 Args:
98 content: Message content / メッセージ内容
99 metadata: Additional metadata / 追加メタデータ
100 """
101 message = Message(
102 role="user",
103 content=content,
104 metadata=metadata or {}
105 )
106 self.messages.append(message)
107 self.last_user_input = content
109 def add_assistant_message(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
110 """
111 Add assistant message to conversation history
112 アシスタントメッセージを会話履歴に追加
114 Args:
115 content: Message content / メッセージ内容
116 metadata: Additional metadata / 追加メタデータ
117 """
118 message = Message(
119 role="assistant",
120 content=content,
121 metadata=metadata or {}
122 )
123 self.messages.append(message)
125 def add_system_message(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
126 """
127 Add system message to conversation history
128 システムメッセージを会話履歴に追加
130 Args:
131 content: Message content / メッセージ内容
132 metadata: Additional metadata / 追加メタデータ
133 """
134 message = Message(
135 role="system",
136 content=content,
137 metadata=metadata or {}
138 )
139 self.messages.append(message)
141 def set_waiting_for_user_input(self, prompt: str) -> None:
142 """
143 Set context to wait for user input with a prompt
144 プロンプトでユーザー入力待ち状態に設定
146 Args:
147 prompt: Prompt to display to user / ユーザーに表示するプロンプト
148 """
149 self.awaiting_prompt = prompt
150 self.awaiting_user_input = True
151 if self._awaiting_prompt_event:
152 self._awaiting_prompt_event.set()
154 def provide_user_input(self, user_input: str) -> None:
155 """
156 Provide user input and clear waiting state
157 ユーザー入力を提供し、待ち状態をクリア
159 Args:
160 user_input: User input text / ユーザー入力テキスト
161 """
162 self.add_user_message(user_input)
163 self.awaiting_prompt = None
164 self.awaiting_user_input = False
165 if self._user_input_event:
166 self._user_input_event.set()
168 def clear_prompt(self) -> Optional[str]:
169 """
170 Clear and return the current prompt
171 現在のプロンプトをクリアして返す
173 Returns:
174 str | None: The prompt if one was waiting / 待機中だったプロンプト
175 """
176 prompt = self.awaiting_prompt
177 self.awaiting_prompt = None
178 if self._awaiting_prompt_event:
179 self._awaiting_prompt_event.clear()
180 return prompt
182 async def wait_for_user_input(self) -> str:
183 """
184 Async wait for user input
185 ユーザー入力を非同期で待機
187 Returns:
188 str: User input / ユーザー入力
189 """
190 if self._user_input_event:
191 await self._user_input_event.wait()
192 self._user_input_event.clear()
193 return self.last_user_input or ""
195 async def wait_for_prompt_event(self) -> str:
196 """
197 Async wait for prompt event
198 プロンプトイベントを非同期で待機
200 Returns:
201 str: Prompt waiting for user / ユーザー待ちのプロンプト
202 """
203 if self._awaiting_prompt_event:
204 await self._awaiting_prompt_event.wait()
205 return self.awaiting_prompt or ""
207 def goto(self, label: str) -> None:
208 """
209 Set next step routing
210 次ステップのルーティングを設定
212 Args:
213 label: Next step label / 次ステップのラベル
214 """
215 self.next_label = label
217 def finish(self) -> None:
218 """
219 Mark flow as finished
220 フローを完了としてマーク
221 """
222 self.next_label = None
224 def is_finished(self) -> bool:
225 """
226 Check if flow is finished
227 フローが完了しているかチェック
229 Returns:
230 bool: True if finished / 完了している場合True
231 """
232 return self.next_label is None
234 @property
235 def finished(self) -> bool:
236 """
237 Property to check if flow is finished
238 フローが完了しているかチェックするプロパティ
240 Returns:
241 bool: True if finished / 完了している場合True
242 """
243 return self.is_finished()
245 def as_dict(self) -> Dict[str, Any]:
246 """
247 Convert to dictionary for LangChain LCEL compatibility
248 LangChain LCEL互換性のために辞書に変換
250 Returns:
251 Dict[str, Any]: Dictionary representation / 辞書表現
252 """
253 data = self.model_dump()
254 # Convert messages to LangChain format
255 # メッセージをLangChain形式に変換
256 data["history"] = [
257 {"role": msg.role, "content": msg.content, "metadata": msg.metadata}
258 for msg in self.messages
259 ]
260 data.pop("messages", None)
261 return data
263 @classmethod
264 def from_dict(cls, data: Dict[str, Any]) -> "Context":
265 """
266 Create Context from dictionary (LangChain LCEL compatibility)
267 辞書からContextを作成(LangChain LCEL互換性)
269 Args:
270 data: Dictionary data / 辞書データ
272 Returns:
273 Context: New context instance / 新しいコンテキストインスタンス
274 """
275 data = data.copy()
276 # Convert history to messages
277 # 履歴をメッセージに変換
278 history = data.pop("history", [])
279 messages = []
280 for msg_data in history:
281 if isinstance(msg_data, dict):
282 messages.append(Message(
283 role=msg_data.get("role", "user"),
284 content=msg_data.get("content", ""),
285 metadata=msg_data.get("metadata", {})
286 ))
287 data["messages"] = messages
288 return cls(**data)
290 def get_conversation_text(self, include_system: bool = False) -> str:
291 """
292 Get conversation as formatted text
293 会話をフォーマット済みテキストとして取得
295 Args:
296 include_system: Include system messages / システムメッセージを含める
298 Returns:
299 str: Formatted conversation / フォーマット済み会話
300 """
301 lines = []
302 for msg in self.messages:
303 if not include_system and msg.role == "system":
304 continue
305 role_label = {"user": "👤", "assistant": "🤖", "system": "⚙️"}.get(msg.role, msg.role)
306 lines.append(f"{role_label} {msg.content}")
307 return "\n".join(lines)
309 def get_last_messages(self, n: int = 10) -> List[Message]:
310 """
311 Get last N messages
312 最後のNメッセージを取得
314 Args:
315 n: Number of messages / メッセージ数
317 Returns:
318 List[Message]: Last N messages / 最後のNメッセージ
319 """
320 return self.messages[-n:] if len(self.messages) > n else self.messages.copy()
322 def update_step_info(self, step_name: str) -> None:
323 """
324 Update current step information
325 現在のステップ情報を更新
327 Args:
328 step_name: Current step name / 現在のステップ名
329 """
330 # Finalize previous span if exists
331 # 前のスパンが存在する場合は終了
332 if self.current_span_id and self.current_step:
333 self._finalize_current_span()
335 # Start new span
336 # 新しいスパンを開始
337 self.current_step = step_name
338 self.step_count += 1
339 self.current_span_id = self._generate_span_id(step_name)
341 # Record span start
342 # スパン開始を記録
343 self._start_span(step_name)
345 def _generate_span_id(self, step_name: str) -> str:
346 """
347 Generate a unique span ID for the step
348 ステップ用のユニークなスパンIDを生成
350 Args:
351 step_name: Step name / ステップ名
353 Returns:
354 str: Generated span ID / 生成されたスパンID
355 """
356 timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
357 return f"{step_name}_{self.step_count:03d}_{timestamp}"
359 def _start_span(self, step_name: str) -> None:
360 """
361 Start a new span for step tracking
362 ステップ追跡用の新しいスパンを開始
364 Args:
365 step_name: Step name / ステップ名
366 """
367 span_info = {
368 "span_id": self.current_span_id,
369 "step_name": step_name,
370 "trace_id": self.trace_id,
371 "start_time": datetime.now(),
372 "end_time": None,
373 "status": "started",
374 "step_index": self.step_count,
375 "metadata": {}
376 }
377 self.span_history.append(span_info)
379 def _finalize_current_span(self, status: str = "completed", error: Optional[str] = None) -> None:
380 """
381 Finalize the current span
382 現在のスパンを終了
384 Args:
385 status: Span status (completed, error, etc.) / スパンステータス
386 error: Error message if failed / 失敗時のエラーメッセージ
387 """
388 if not self.span_history:
389 return
391 current_span = self.span_history[-1]
392 if current_span["span_id"] == self.current_span_id:
393 current_span["end_time"] = datetime.now()
394 current_span["status"] = status
395 if error:
396 current_span["error"] = error
398 def finalize_flow_span(self) -> None:
399 """
400 Finalize the current span when flow ends
401 フロー終了時に現在のスパンを終了
402 """
403 if self.current_span_id:
404 self._finalize_current_span()
405 self.current_span_id = None
407 def set_artifact(self, key: str, value: Any) -> None:
408 """
409 Set artifact value
410 成果物の値を設定
412 Args:
413 key: Artifact key / 成果物キー
414 value: Artifact value / 成果物値
415 """
416 self.artifacts[key] = value
418 def get_artifact(self, key: str, default: Any = None) -> Any:
419 """
420 Get artifact value
421 成果物の値を取得
423 Args:
424 key: Artifact key / 成果物キー
425 default: Default value if not found / 見つからない場合のデフォルト値
427 Returns:
428 Any: Artifact value / 成果物値
429 """
430 return self.artifacts.get(key, default)
432 def get_current_span_info(self) -> Optional[Dict[str, Any]]:
433 """
434 Get current span information
435 現在のスパン情報を取得
437 Returns:
438 Dict[str, Any] | None: Current span info / 現在のスパン情報
439 """
440 if self.current_span_id and self.span_history:
441 return self.span_history[-1]
442 return None
444 def get_span_history(self) -> List[Dict[str, Any]]:
445 """
446 Get complete span execution history
447 完全なスパン実行履歴を取得
449 Returns:
450 List[Dict[str, Any]]: Span history / スパン履歴
451 """
452 return self.span_history.copy()
454 def get_trace_summary(self) -> Dict[str, Any]:
455 """
456 Get comprehensive trace summary
457 包括的なトレースサマリーを取得
459 Returns:
460 Dict[str, Any]: Trace summary / トレースサマリー
461 """
462 total_duration = None
463 if self.span_history:
464 start_time = min(span["start_time"] for span in self.span_history)
465 completed_spans = [span for span in self.span_history if span.get("end_time")]
466 if completed_spans:
467 end_time = max(span["end_time"] for span in completed_spans)
468 total_duration = (end_time - start_time).total_seconds()
470 return {
471 "trace_id": self.trace_id,
472 "current_span_id": self.current_span_id,
473 "total_spans": len(self.span_history),
474 "completed_spans": len([s for s in self.span_history if s.get("status") == "completed"]),
475 "active_spans": len([s for s in self.span_history if s.get("status") == "started"]),
476 "error_spans": len([s for s in self.span_history if s.get("status") == "error"]),
477 "total_duration_seconds": total_duration,
478 "flow_start_time": self.start_time,
479 "is_finished": self.is_finished()
480 }
482 class Config:
483 # Allow arbitrary types for flexibility
484 # 柔軟性のために任意の型を許可
485 arbitrary_types_allowed = True