Coverage for src/refinire/agents/flow/step.py: 83%
191 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 15:27 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 15:27 +0900
1from __future__ import annotations
3"""Step — Step interface and basic implementations for Flow workflows.
5Stepはフローワークフロー用のステップインターフェースと基本実装を提供します。
6UserInputStep、ConditionStep、ForkStep、JoinStepなどの基本的なステップを含みます。
7"""
9import asyncio
10from abc import ABC, abstractmethod
11from typing import Any, Callable, Dict, List, Optional, Union, Awaitable
12from concurrent.futures import ThreadPoolExecutor
13import threading
15from .context import Context
18class Step(ABC):
19 """
20 Abstract base class for workflow steps
21 ワークフローステップの抽象基底クラス
23 All step implementations must provide:
24 全てのステップ実装は以下を提供する必要があります:
25 - name: Step identifier for DSL reference / DSL参照用ステップ識別子
26 - run: Async execution method / 非同期実行メソッド
27 """
29 def __init__(self, name: str):
30 """
31 Initialize step with name
32 名前でステップを初期化
34 Args:
35 name: Step name / ステップ名
36 """
37 self.name = name
39 @abstractmethod
40 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
41 """
42 Execute step and return updated context
43 ステップを実行し、更新されたコンテキストを返す
45 Args:
46 user_input: User input if any / ユーザー入力(あれば)
47 ctx: Current context / 現在のコンテキスト
49 Returns:
50 Context: Updated context with next_label set / next_labelが設定された更新済みコンテキスト
51 """
52 pass
54 def __str__(self) -> str:
55 return f"{self.__class__.__name__}({self.name})"
57 def __repr__(self) -> str:
58 return self.__str__()
61class UserInputStep(Step):
62 """
63 Step that waits for user input
64 ユーザー入力を待機するステップ
66 This step displays a prompt and waits for user response.
67 このステップはプロンプトを表示し、ユーザー応答を待機します。
68 It sets the context to waiting state and returns without advancing.
69 コンテキストを待機状態に設定し、進行せずに返します。
70 """
72 def __init__(self, name: str, prompt: str, next_step: Optional[str] = None):
73 """
74 Initialize user input step
75 ユーザー入力ステップを初期化
77 Args:
78 name: Step name / ステップ名
79 prompt: Prompt to display to user / ユーザーに表示するプロンプト
80 next_step: Next step after input (optional) / 入力後の次ステップ(オプション)
81 """
82 super().__init__(name)
83 self.prompt = prompt
84 self.next_step = next_step
86 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
87 """
88 Execute user input step
89 ユーザー入力ステップを実行
91 Args:
92 user_input: User input if available / 利用可能なユーザー入力
93 ctx: Current context / 現在のコンテキスト
95 Returns:
96 Context: Updated context / 更新済みコンテキスト
97 """
98 ctx.update_step_info(self.name)
100 # If user input is provided, process it
101 # ユーザー入力が提供されている場合、処理する
102 if user_input is not None:
103 ctx.provide_user_input(user_input)
104 if self.next_step:
105 ctx.goto(self.next_step)
106 else:
107 ctx.finish()
108 # Note: If next_step is None, flow will end
109 # 注:next_stepがNoneの場合、フローは終了
110 else:
111 # Set waiting state for user input
112 # ユーザー入力の待機状態を設定
113 ctx.set_waiting_for_user_input(self.prompt)
115 return ctx
118class ConditionStep(Step):
119 """
120 Step that performs conditional routing
121 条件付きルーティングを実行するステップ
123 This step evaluates a condition and routes to different steps based on the result.
124 このステップは条件を評価し、結果に基づいて異なるステップにルーティングします。
125 """
127 def __init__(
128 self,
129 name: str,
130 condition: Callable[[Context], Union[bool, Awaitable[bool]]],
131 if_true: str,
132 if_false: str
133 ):
134 """
135 Initialize condition step
136 条件ステップを初期化
138 Args:
139 name: Step name / ステップ名
140 condition: Condition function / 条件関数
141 if_true: Step to go if condition is True / 条件がTrueの場合のステップ
142 if_false: Step to go if condition is False / 条件がFalseの場合のステップ
143 """
144 super().__init__(name)
145 self.condition = condition
146 self.if_true = if_true
147 self.if_false = if_false
149 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
150 """
151 Execute condition step
152 条件ステップを実行
154 Args:
155 user_input: User input (not used) / ユーザー入力(使用されない)
156 ctx: Current context / 現在のコンテキスト
158 Returns:
159 Context: Updated context with routing / ルーティング付き更新済みコンテキスト
160 """
161 ctx.update_step_info(self.name)
163 # Evaluate condition (may be async)
164 # 条件を評価(非同期の可能性あり)
165 try:
166 result = self.condition(ctx)
167 if asyncio.iscoroutine(result):
168 result = await result
169 except Exception as e:
170 # On error, go to false branch
171 # エラー時はfalseブランチに進む
172 ctx.add_system_message(f"Condition evaluation error: {e}")
173 result = False
175 # Route based on condition result
176 # 条件結果に基づいてルーティング
177 next_step = self.if_true if result else self.if_false
178 ctx.goto(next_step)
180 return ctx
183class FunctionStep(Step):
184 """
185 Step that executes a custom function
186 カスタム関数を実行するステップ
188 This step allows executing arbitrary code within the workflow.
189 このステップはワークフロー内で任意のコードを実行できます。
190 """
192 def __init__(
193 self,
194 name: str,
195 function: Callable[[Optional[str], Context], Union[Context, Awaitable[Context]]],
196 next_step: Optional[str] = None
197 ):
198 """
199 Initialize function step
200 関数ステップを初期化
202 Args:
203 name: Step name / ステップ名
204 function: Function to execute / 実行する関数
205 next_step: Next step after execution / 実行後の次ステップ
206 """
207 super().__init__(name)
208 self.function = function
209 self.next_step = next_step
211 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
212 """
213 Execute function step
214 関数ステップを実行
216 Args:
217 user_input: User input / ユーザー入力
218 ctx: Current context / 現在のコンテキスト
220 Returns:
221 Context: Updated context / 更新済みコンテキスト
222 """
223 ctx.update_step_info(self.name)
225 try:
226 # Execute the function (may be async)
227 # 関数を実行(非同期の可能性あり)
228 result = self.function(user_input, ctx)
229 if asyncio.iscoroutine(result):
230 ctx = await result
231 else:
232 ctx = result
233 except Exception as e:
234 ctx.add_system_message(f"Function execution error in {self.name}: {e}")
236 # Set next step if specified, otherwise finish the flow
237 # 指定されている場合は次ステップを設定、そうでなければフローを終了
238 if self.next_step:
239 ctx.goto(self.next_step)
240 else:
241 ctx.finish()
243 return ctx
246class ForkStep(Step):
247 """
248 Step that executes multiple branches in parallel
249 複数のブランチを並列実行するステップ
251 This step starts multiple sub-flows concurrently and collects their results.
252 このステップは複数のサブフローを同時に開始し、結果を収集します。
253 """
255 def __init__(self, name: str, branches: List[str], join_step: str):
256 """
257 Initialize fork step
258 フォークステップを初期化
260 Args:
261 name: Step name / ステップ名
262 branches: List of branch step names to execute in parallel / 並列実行するブランチステップ名のリスト
263 join_step: Step to join results / 結果を結合するステップ
264 """
265 super().__init__(name)
266 self.branches = branches
267 self.join_step = join_step
269 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
270 """
271 Execute fork step
272 フォークステップを実行
274 Args:
275 user_input: User input / ユーザー入力
276 ctx: Current context / 現在のコンテキスト
278 Returns:
279 Context: Updated context / 更新済みコンテキスト
280 """
281 ctx.update_step_info(self.name)
283 # Store branch information for join step
284 # ジョインステップ用にブランチ情報を保存
285 ctx.shared_state[f"{self.name}_branches"] = self.branches
286 ctx.shared_state[f"{self.name}_started"] = True
288 # For now, just route to the join step
289 # 現在のところ、ジョインステップにルーティングするだけ
290 # In a full implementation, this would start parallel execution
291 # 完全な実装では、これは並列実行を開始する
292 ctx.goto(self.join_step)
294 return ctx
297class JoinStep(Step):
298 """
299 Step that joins results from parallel branches
300 並列ブランチからの結果を結合するステップ
302 This step waits for parallel branches to complete and merges their results.
303 このステップは並列ブランチの完了を待機し、結果をマージします。
304 """
306 def __init__(self, name: str, fork_step: str, join_type: str = "all", next_step: Optional[str] = None):
307 """
308 Initialize join step
309 ジョインステップを初期化
311 Args:
312 name: Step name / ステップ名
313 fork_step: Associated fork step name / 関連するフォークステップ名
314 join_type: Join type ("all" or "any") / ジョインタイプ("all"または"any")
315 next_step: Next step after join / ジョイン後の次ステップ
316 """
317 super().__init__(name)
318 self.fork_step = fork_step
319 self.join_type = join_type
320 self.next_step = next_step
322 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
323 """
324 Execute join step
325 ジョインステップを実行
327 Args:
328 user_input: User input / ユーザー入力
329 ctx: Current context / 現在のコンテキスト
331 Returns:
332 Context: Updated context / 更新済みコンテキスト
333 """
334 ctx.update_step_info(self.name)
336 # Get branch information from shared state
337 # 共有状態からブランチ情報を取得
338 branches = ctx.shared_state.get(f"{self.fork_step}_branches", [])
340 # For now, just mark as completed
341 # 現在のところ、完了としてマークするだけ
342 # In a full implementation, this would wait for and merge branch results
343 # 完全な実装では、これはブランチ結果を待機してマージする
344 ctx.add_system_message(f"Joined {len(branches)} branches using {self.join_type} strategy")
346 # Set next step if specified, otherwise finish the flow
347 # 指定されている場合は次ステップを設定、そうでなければフローを終了
348 if self.next_step:
349 ctx.goto(self.next_step)
350 else:
351 ctx.finish()
353 return ctx
357class DebugStep(Step):
358 """
359 Step for debugging and logging
360 デバッグとログ用ステップ
362 This step prints or logs context information for debugging purposes.
363 このステップはデバッグ目的でコンテキスト情報を印刷またはログ出力します。
364 """
366 def __init__(self, name: str, message: str = "", print_context: bool = False, next_step: Optional[str] = None):
367 """
368 Initialize debug step
369 デバッグステップを初期化
371 Args:
372 name: Step name / ステップ名
373 message: Debug message / デバッグメッセージ
374 print_context: Whether to print full context / 完全なコンテキストを印刷するか
375 next_step: Next step / 次ステップ
376 """
377 super().__init__(name)
378 self.message = message
379 self.print_context = print_context
380 self.next_step = next_step
382 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
383 """
384 Execute debug step
385 デバッグステップを実行
387 Args:
388 user_input: User input / ユーザー入力
389 ctx: Current context / 現在のコンテキスト
391 Returns:
392 Context: Updated context / 更新済みコンテキスト
393 """
394 ctx.update_step_info(self.name)
396 # Print debug information
397 # デバッグ情報を印刷
398 print(f"🐛 DEBUG [{self.name}]: {self.message}")
399 if user_input:
400 print(f" User Input: {user_input}")
401 print(f" Step Count: {ctx.step_count}")
402 print(f" Next Label: {ctx.next_label}")
404 if self.print_context:
405 print(f" Context: {ctx.model_dump()}")
407 # Add debug message to system messages
408 # デバッグメッセージをシステムメッセージに追加
409 ctx.add_system_message(f"DEBUG {self.name}: {self.message}")
411 # Set next step if specified, otherwise finish the flow
412 # 指定されている場合は次ステップを設定、そうでなければフローを終了
413 if self.next_step:
414 ctx.goto(self.next_step)
415 else:
416 ctx.finish()
418 return ctx
421# Utility functions for creating common step patterns
422# 一般的なステップパターンを作成するユーティリティ関数
424def create_simple_condition(field_path: str, expected_value: Any) -> Callable[[Context], bool]:
425 """
426 Create a simple condition function that checks a field value
427 フィールド値をチェックする簡単な条件関数を作成
429 Args:
430 field_path: Dot-separated path to field (e.g., "shared_state.status") / フィールドへのドット区切りパス
431 expected_value: Expected value / 期待値
433 Returns:
434 Callable[[Context], bool]: Condition function / 条件関数
435 """
436 def condition(ctx: Context) -> bool:
437 try:
438 # Navigate to the field using dot notation
439 # ドット記法を使用してフィールドに移動
440 obj = ctx
441 for part in field_path.split('.'):
442 if hasattr(obj, part):
443 obj = getattr(obj, part)
444 elif isinstance(obj, dict) and part in obj:
445 obj = obj[part]
446 else:
447 return False
448 return obj == expected_value
449 except Exception:
450 return False
452 return condition
455def create_lambda_step(name: str, func: Callable[[Context], Any], next_step: Optional[str] = None) -> FunctionStep:
456 """
457 Create a simple function step from a lambda
458 ラムダから簡単な関数ステップを作成
460 Args:
461 name: Step name / ステップ名
462 func: Function to execute / 実行する関数
463 next_step: Next step / 次ステップ
465 Returns:
466 FunctionStep: Function step / 関数ステップ
467 """
468 def wrapper(user_input: Optional[str], ctx: Context) -> Context:
469 func(ctx)
470 return ctx
472 return FunctionStep(name, wrapper, next_step)
475class ParallelStep(Step):
476 """
477 Step that executes multiple steps in parallel
478 複数のステップを並列実行するステップ
480 This step automatically manages parallel execution of child steps.
481 このステップは子ステップの並列実行を自動管理します。
482 It waits for all parallel steps to complete before proceeding.
483 全ての並列ステップが完了するまで待機してから進行します。
484 """
486 def __init__(
487 self,
488 name: str,
489 parallel_steps: List[Step],
490 next_step: Optional[str] = None,
491 max_workers: Optional[int] = None
492 ):
493 """
494 Initialize parallel step
495 並列ステップを初期化
497 Args:
498 name: Step name / ステップ名
499 parallel_steps: List of steps to execute in parallel / 並列実行するステップのリスト
500 next_step: Next step after all parallel steps complete / 全並列ステップ完了後の次ステップ
501 max_workers: Maximum number of concurrent workers / 最大同時ワーカー数
502 """
503 super().__init__(name)
504 self.parallel_steps = parallel_steps
505 self.next_step = next_step
506 self.max_workers = max_workers or min(32, len(parallel_steps) + 4)
508 # Validate that all steps have names
509 # 全ステップに名前があることを検証
510 for step in parallel_steps:
511 if not hasattr(step, 'name') or not step.name:
512 raise ValueError(f"All parallel steps must have valid names: {step}")
514 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
515 """
516 Execute parallel steps
517 並列ステップを実行
519 Args:
520 user_input: User input (passed to all parallel steps) / ユーザー入力(全並列ステップに渡される)
521 ctx: Current context / 現在のコンテキスト
523 Returns:
524 Context: Updated context with merged results / マージされた結果を持つ更新済みコンテキスト
525 """
526 ctx.update_step_info(self.name)
528 # Create separate contexts for each parallel step
529 # 各並列ステップ用に別々のコンテキストを作成
530 parallel_contexts = []
531 for step in self.parallel_steps:
532 # Clone context for each parallel execution
533 # 各並列実行用にコンテキストをクローン
534 step_ctx = self._clone_context_for_parallel(ctx, step.name)
535 parallel_contexts.append((step, step_ctx))
537 # Execute all steps in parallel
538 # 全ステップを並列実行
539 async def run_parallel_step(step_and_ctx):
540 step, step_ctx = step_and_ctx
541 try:
542 result_ctx = await step.run(user_input, step_ctx)
543 return step.name, result_ctx, None
544 except Exception as e:
545 return step.name, step_ctx, e
547 # Use asyncio.gather for parallel execution
548 # 並列実行にasyncio.gatherを使用
549 results = await asyncio.gather(
550 *[run_parallel_step(sc) for sc in parallel_contexts],
551 return_exceptions=True
552 )
554 # Merge results back into main context
555 # 結果をメインコンテキストにマージ
556 errors = []
557 for result in results:
558 if isinstance(result, Exception):
559 errors.append(result)
560 continue
562 step_name, result_ctx, error = result
563 if error:
564 errors.append(f"Step {step_name}: {error}")
565 continue
567 # Merge parallel step results
568 # 並列ステップ結果をマージ
569 self._merge_parallel_result(ctx, step_name, result_ctx)
571 # Handle errors if any
572 # エラーがあれば処理
573 if errors:
574 error_msg = f"Parallel execution errors: {'; '.join(map(str, errors))}"
575 ctx.add_system_message(error_msg)
576 raise RuntimeError(error_msg)
578 # Set next step or finish
579 # 次ステップを設定または終了
580 if self.next_step:
581 ctx.goto(self.next_step)
582 else:
583 ctx.finish()
585 return ctx
587 def _clone_context_for_parallel(self, ctx: Context, step_name: str) -> Context:
588 """
589 Clone context for parallel execution
590 並列実行用にコンテキストをクローン
592 Args:
593 ctx: Original context / 元のコンテキスト
594 step_name: Name of the step / ステップ名
596 Returns:
597 Context: Cloned context / クローンされたコンテキスト
598 """
599 # Create new context with shared state
600 # 共有状態を持つ新しいコンテキストを作成
601 cloned_ctx = Context()
603 # Copy essential state
604 # 必須状態をコピー
605 cloned_ctx.shared_state = ctx.shared_state.copy()
606 cloned_ctx.messages = ctx.messages.copy()
607 cloned_ctx.last_user_input = ctx.last_user_input
608 cloned_ctx.trace_id = ctx.trace_id
609 cloned_ctx.span_history = ctx.span_history.copy()
611 # Set step-specific information
612 # ステップ固有情報を設定
613 cloned_ctx.current_step = step_name
615 return cloned_ctx
617 def _merge_parallel_result(self, main_ctx: Context, step_name: str, result_ctx: Context) -> None:
618 """
619 Merge parallel step result into main context
620 並列ステップ結果をメインコンテキストにマージ
622 Args:
623 main_ctx: Main context / メインコンテキスト
624 step_name: Name of the completed step / 完了したステップ名
625 result_ctx: Result context from parallel step / 並列ステップからの結果コンテキスト
626 """
627 # Merge shared state with step-specific keys
628 # ステップ固有キーで共有状態をマージ
629 for key, value in result_ctx.shared_state.items():
630 if key not in main_ctx.shared_state:
631 main_ctx.shared_state[key] = value
632 else:
633 # Handle conflicts by prefixing with step name
634 # ステップ名をプレフィックスとして衝突を処理
635 prefixed_key = f"{step_name}_{key}"
636 main_ctx.shared_state[prefixed_key] = value
638 # Merge conversation history
639 # 会話履歴をマージ
640 main_ctx.messages.extend(result_ctx.messages)
642 # Update execution path
643 # 実行パスを更新
644 main_ctx.span_history.extend(result_ctx.span_history)
646 # Store step-specific results metadata (avoid overwriting user data)
647 # ステップ固有結果メタデータを保存(ユーザーデータの上書きを避ける)
648 main_ctx.shared_state[f"__{step_name}_metadata__"] = {
649 "status": "completed",
650 "output": result_ctx.shared_state,
651 "messages": result_ctx.messages
652 }