Coverage for src/refinire/agents/flow/step.py: 76%
214 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-15 18:51 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-15 18:51 +0900
1from __future__ import annotations
3"""Step — Step interface and basic implementations for Flow workflows.
5Stepはフローワークフロー用のステップインターフェースと基本実装を提供します。
6UserInputStep、ConditionStep、ForkStep、JoinStepなどの基本的なステップを含みます。
7"""
9import asyncio
10from abc import ABC, abstractmethod
11from typing import Any, Callable, Dict, List, Optional, Union, Awaitable
12from concurrent.futures import ThreadPoolExecutor
13import threading
15from .context import Context
18class Step(ABC):
19 """
20 Abstract base class for workflow steps
21 ワークフローステップの抽象基底クラス
23 All step implementations must provide:
24 全てのステップ実装は以下を提供する必要があります:
25 - name: Step identifier for DSL reference / DSL参照用ステップ識別子
26 - run: Async execution method / 非同期実行メソッド
27 """
29 def __init__(self, name: str):
30 """
31 Initialize step with name
32 名前でステップを初期化
34 Args:
35 name: Step name / ステップ名
36 """
37 self.name = name
39 @abstractmethod
40 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
41 """
42 Execute step and return updated context
43 ステップを実行し、更新されたコンテキストを返す
45 Args:
46 user_input: User input if any / ユーザー入力(あれば)
47 ctx: Current context / 現在のコンテキスト
49 Returns:
50 Context: Updated context with next_label set / next_labelが設定された更新済みコンテキスト
51 """
52 pass
54 def __str__(self) -> str:
55 return f"{self.__class__.__name__}({self.name})"
57 def __repr__(self) -> str:
58 return self.__str__()
61class UserInputStep(Step):
62 """
63 Step that waits for user input
64 ユーザー入力を待機するステップ
66 This step displays a prompt and waits for user response.
67 このステップはプロンプトを表示し、ユーザー応答を待機します。
68 It sets the context to waiting state and returns without advancing.
69 コンテキストを待機状態に設定し、進行せずに返します。
70 """
72 def __init__(self, name: str, prompt: str, next_step: Optional[str] = None):
73 """
74 Initialize user input step
75 ユーザー入力ステップを初期化
77 Args:
78 name: Step name / ステップ名
79 prompt: Prompt to display to user / ユーザーに表示するプロンプト
80 next_step: Next step after input (optional) / 入力後の次ステップ(オプション)
81 """
82 super().__init__(name)
83 self.prompt = prompt
84 self.next_step = next_step
86 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
87 """
88 Execute user input step
89 ユーザー入力ステップを実行
91 Args:
92 user_input: User input if available / 利用可能なユーザー入力
93 ctx: Current context / 現在のコンテキスト
95 Returns:
96 Context: Updated context / 更新済みコンテキスト
97 """
98 ctx.update_step_info(self.name)
100 # If user input is provided, process it
101 # ユーザー入力が提供されている場合、処理する
102 if user_input is not None:
103 ctx.provide_user_input(user_input)
104 if self.next_step:
105 ctx.goto(self.next_step)
106 else:
107 ctx.finish()
108 # Note: If next_step is None, flow will end
109 # 注:next_stepがNoneの場合、フローは終了
110 else:
111 # Set waiting state for user input
112 # ユーザー入力の待機状態を設定
113 ctx.set_waiting_for_user_input(self.prompt)
115 return ctx
118class ConditionStep(Step):
119 """
120 Step that performs conditional routing
121 条件付きルーティングを実行するステップ
123 This step evaluates a condition and routes to different steps based on the result.
124 このステップは条件を評価し、結果に基づいて異なるステップにルーティングします。
125 """
127 def __init__(
128 self,
129 name: str,
130 condition: Callable[[Context], Union[bool, Awaitable[bool]]],
131 if_true: str,
132 if_false: str
133 ):
134 """
135 Initialize condition step
136 条件ステップを初期化
138 Args:
139 name: Step name / ステップ名
140 condition: Condition function / 条件関数
141 if_true: Step to go if condition is True / 条件がTrueの場合のステップ
142 if_false: Step to go if condition is False / 条件がFalseの場合のステップ
143 """
144 super().__init__(name)
145 self.condition = condition
146 self.if_true = if_true
147 self.if_false = if_false
149 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
150 """
151 Execute condition step
152 条件ステップを実行
154 Args:
155 user_input: User input (not used) / ユーザー入力(使用されない)
156 ctx: Current context / 現在のコンテキスト
158 Returns:
159 Context: Updated context with routing / ルーティング付き更新済みコンテキスト
160 """
161 ctx.update_step_info(self.name)
163 # Evaluate condition (may be async)
164 # 条件を評価(非同期の可能性あり)
165 try:
166 result = self.condition(ctx)
167 if asyncio.iscoroutine(result):
168 result = await result
169 except Exception as e:
170 # On error, go to false branch
171 # エラー時はfalseブランチに進む
172 ctx.add_system_message(f"Condition evaluation error: {e}")
173 result = False
175 # Route based on condition result
176 # 条件結果に基づいてルーティング
177 next_step = self.if_true if result else self.if_false
178 ctx.goto(next_step)
180 return ctx
183class FunctionStep(Step):
184 """
185 Step that executes a custom function
186 カスタム関数を実行するステップ
188 This step allows executing arbitrary code within the workflow.
189 このステップはワークフロー内で任意のコードを実行できます。
190 """
192 def __init__(
193 self,
194 name: str,
195 function: Callable[[Optional[str], Context], Union[Context, Awaitable[Context]]],
196 next_step: Optional[str] = None
197 ):
198 """
199 Initialize function step
200 関数ステップを初期化
202 Args:
203 name: Step name / ステップ名
204 function: Function to execute / 実行する関数
205 next_step: Next step after execution / 実行後の次ステップ
206 """
207 super().__init__(name)
208 self.function = function
209 self.next_step = next_step
211 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
212 """
213 Execute function step
214 関数ステップを実行
216 Args:
217 user_input: User input / ユーザー入力
218 ctx: Current context / 現在のコンテキスト
220 Returns:
221 Context: Updated context / 更新済みコンテキスト
222 """
223 ctx.update_step_info(self.name)
225 try:
226 # Execute the function (may be async)
227 # 関数を実行(非同期の可能性あり)
228 result = self.function(user_input, ctx)
229 if asyncio.iscoroutine(result):
230 ctx = await result
231 else:
232 ctx = result
233 except Exception as e:
234 ctx.add_system_message(f"Function execution error in {self.name}: {e}")
236 # Set next step if specified, otherwise finish the flow
237 # 指定されている場合は次ステップを設定、そうでなければフローを終了
238 if self.next_step:
239 ctx.goto(self.next_step)
240 else:
241 ctx.finish()
243 return ctx
246class ForkStep(Step):
247 """
248 Step that executes multiple branches in parallel
249 複数のブランチを並列実行するステップ
251 This step starts multiple sub-flows concurrently and collects their results.
252 このステップは複数のサブフローを同時に開始し、結果を収集します。
253 """
255 def __init__(self, name: str, branches: List[str], join_step: str):
256 """
257 Initialize fork step
258 フォークステップを初期化
260 Args:
261 name: Step name / ステップ名
262 branches: List of branch step names to execute in parallel / 並列実行するブランチステップ名のリスト
263 join_step: Step to join results / 結果を結合するステップ
264 """
265 super().__init__(name)
266 self.branches = branches
267 self.join_step = join_step
269 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
270 """
271 Execute fork step
272 フォークステップを実行
274 Args:
275 user_input: User input / ユーザー入力
276 ctx: Current context / 現在のコンテキスト
278 Returns:
279 Context: Updated context / 更新済みコンテキスト
280 """
281 ctx.update_step_info(self.name)
283 # Store branch information for join step
284 # ジョインステップ用にブランチ情報を保存
285 ctx.shared_state[f"{self.name}_branches"] = self.branches
286 ctx.shared_state[f"{self.name}_started"] = True
288 # For now, just route to the join step
289 # 現在のところ、ジョインステップにルーティングするだけ
290 # In a full implementation, this would start parallel execution
291 # 完全な実装では、これは並列実行を開始する
292 ctx.goto(self.join_step)
294 return ctx
297class JoinStep(Step):
298 """
299 Step that joins results from parallel branches
300 並列ブランチからの結果を結合するステップ
302 This step waits for parallel branches to complete and merges their results.
303 このステップは並列ブランチの完了を待機し、結果をマージします。
304 """
306 def __init__(self, name: str, fork_step: str, join_type: str = "all", next_step: Optional[str] = None):
307 """
308 Initialize join step
309 ジョインステップを初期化
311 Args:
312 name: Step name / ステップ名
313 fork_step: Associated fork step name / 関連するフォークステップ名
314 join_type: Join type ("all" or "any") / ジョインタイプ("all"または"any")
315 next_step: Next step after join / ジョイン後の次ステップ
316 """
317 super().__init__(name)
318 self.fork_step = fork_step
319 self.join_type = join_type
320 self.next_step = next_step
322 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
323 """
324 Execute join step
325 ジョインステップを実行
327 Args:
328 user_input: User input / ユーザー入力
329 ctx: Current context / 現在のコンテキスト
331 Returns:
332 Context: Updated context / 更新済みコンテキスト
333 """
334 ctx.update_step_info(self.name)
336 # Get branch information from shared state
337 # 共有状態からブランチ情報を取得
338 branches = ctx.shared_state.get(f"{self.fork_step}_branches", [])
340 # For now, just mark as completed
341 # 現在のところ、完了としてマークするだけ
342 # In a full implementation, this would wait for and merge branch results
343 # 完全な実装では、これはブランチ結果を待機してマージする
344 ctx.add_system_message(f"Joined {len(branches)} branches using {self.join_type} strategy")
346 # Set next step if specified, otherwise finish the flow
347 # 指定されている場合は次ステップを設定、そうでなければフローを終了
348 if self.next_step:
349 ctx.goto(self.next_step)
350 else:
351 ctx.finish()
353 return ctx
356class AgentPipelineStep(Step):
357 """
358 Step that wraps AgentPipeline for use in Flow
359 FlowでAgentPipelineを使用するためのラッパーステップ
361 This step allows using existing AgentPipeline instances as flow steps.
362 このステップは既存のAgentPipelineインスタンスをフローステップとして使用できます。
363 """
365 def __init__(self, name: str, pipeline: Any, next_step: Optional[str] = None):
366 """
367 Initialize agent pipeline step
368 エージェントパイプラインステップを初期化
370 Args:
371 name: Step name / ステップ名
372 pipeline: AgentPipeline instance / AgentPipelineインスタンス
373 next_step: Next step after pipeline execution / パイプライン実行後の次ステップ
374 """
375 super().__init__(name)
376 self.pipeline = pipeline
377 self.next_step = next_step
379 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
380 """
381 Execute agent pipeline step
382 エージェントパイプラインステップを実行
384 Args:
385 user_input: User input / ユーザー入力
386 ctx: Current context / 現在のコンテキスト
388 Returns:
389 Context: Updated context / 更新済みコンテキスト
390 """
391 ctx.update_step_info(self.name)
393 try:
394 # Use the last user input if available
395 # 利用可能な場合は最後のユーザー入力を使用
396 input_text = user_input or ctx.last_user_input or ""
398 # Execute pipeline in thread pool to handle sync methods
399 # 同期メソッドを処理するためにスレッドプールでパイプラインを実行
400 loop = asyncio.get_event_loop()
401 with ThreadPoolExecutor() as executor:
402 future = loop.run_in_executor(executor, self.pipeline.run, input_text)
403 result = await future
405 # Store result in context
406 # 結果をコンテキストに保存
407 if result is not None:
408 ctx.prev_outputs[self.name] = result
409 ctx.add_assistant_message(str(result))
411 except Exception as e:
412 ctx.add_system_message(f"Pipeline execution error in {self.name}: {e}")
413 ctx.prev_outputs[self.name] = None
415 # Set next step if specified, otherwise finish the flow
416 # 指定されている場合は次ステップを設定、そうでなければフローを終了
417 if self.next_step:
418 ctx.goto(self.next_step)
419 else:
420 ctx.finish()
422 return ctx
425class DebugStep(Step):
426 """
427 Step for debugging and logging
428 デバッグとログ用ステップ
430 This step prints or logs context information for debugging purposes.
431 このステップはデバッグ目的でコンテキスト情報を印刷またはログ出力します。
432 """
434 def __init__(self, name: str, message: str = "", print_context: bool = False, next_step: Optional[str] = None):
435 """
436 Initialize debug step
437 デバッグステップを初期化
439 Args:
440 name: Step name / ステップ名
441 message: Debug message / デバッグメッセージ
442 print_context: Whether to print full context / 完全なコンテキストを印刷するか
443 next_step: Next step / 次ステップ
444 """
445 super().__init__(name)
446 self.message = message
447 self.print_context = print_context
448 self.next_step = next_step
450 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
451 """
452 Execute debug step
453 デバッグステップを実行
455 Args:
456 user_input: User input / ユーザー入力
457 ctx: Current context / 現在のコンテキスト
459 Returns:
460 Context: Updated context / 更新済みコンテキスト
461 """
462 ctx.update_step_info(self.name)
464 # Print debug information
465 # デバッグ情報を印刷
466 print(f"🐛 DEBUG [{self.name}]: {self.message}")
467 if user_input:
468 print(f" User Input: {user_input}")
469 print(f" Step Count: {ctx.step_count}")
470 print(f" Next Label: {ctx.next_label}")
472 if self.print_context:
473 print(f" Context: {ctx.model_dump()}")
475 # Add debug message to system messages
476 # デバッグメッセージをシステムメッセージに追加
477 ctx.add_system_message(f"DEBUG {self.name}: {self.message}")
479 # Set next step if specified, otherwise finish the flow
480 # 指定されている場合は次ステップを設定、そうでなければフローを終了
481 if self.next_step:
482 ctx.goto(self.next_step)
483 else:
484 ctx.finish()
486 return ctx
489# Utility functions for creating common step patterns
490# 一般的なステップパターンを作成するユーティリティ関数
492def create_simple_condition(field_path: str, expected_value: Any) -> Callable[[Context], bool]:
493 """
494 Create a simple condition function that checks a field value
495 フィールド値をチェックする簡単な条件関数を作成
497 Args:
498 field_path: Dot-separated path to field (e.g., "shared_state.status") / フィールドへのドット区切りパス
499 expected_value: Expected value / 期待値
501 Returns:
502 Callable[[Context], bool]: Condition function / 条件関数
503 """
504 def condition(ctx: Context) -> bool:
505 try:
506 # Navigate to the field using dot notation
507 # ドット記法を使用してフィールドに移動
508 obj = ctx
509 for part in field_path.split('.'):
510 if hasattr(obj, part):
511 obj = getattr(obj, part)
512 elif isinstance(obj, dict) and part in obj:
513 obj = obj[part]
514 else:
515 return False
516 return obj == expected_value
517 except Exception:
518 return False
520 return condition
523def create_lambda_step(name: str, func: Callable[[Context], Any], next_step: Optional[str] = None) -> FunctionStep:
524 """
525 Create a simple function step from a lambda
526 ラムダから簡単な関数ステップを作成
528 Args:
529 name: Step name / ステップ名
530 func: Function to execute / 実行する関数
531 next_step: Next step / 次ステップ
533 Returns:
534 FunctionStep: Function step / 関数ステップ
535 """
536 def wrapper(user_input: Optional[str], ctx: Context) -> Context:
537 func(ctx)
538 return ctx
540 return FunctionStep(name, wrapper, next_step)
543class ParallelStep(Step):
544 """
545 Step that executes multiple steps in parallel
546 複数のステップを並列実行するステップ
548 This step automatically manages parallel execution of child steps.
549 このステップは子ステップの並列実行を自動管理します。
550 It waits for all parallel steps to complete before proceeding.
551 全ての並列ステップが完了するまで待機してから進行します。
552 """
554 def __init__(
555 self,
556 name: str,
557 parallel_steps: List[Step],
558 next_step: Optional[str] = None,
559 max_workers: Optional[int] = None
560 ):
561 """
562 Initialize parallel step
563 並列ステップを初期化
565 Args:
566 name: Step name / ステップ名
567 parallel_steps: List of steps to execute in parallel / 並列実行するステップのリスト
568 next_step: Next step after all parallel steps complete / 全並列ステップ完了後の次ステップ
569 max_workers: Maximum number of concurrent workers / 最大同時ワーカー数
570 """
571 super().__init__(name)
572 self.parallel_steps = parallel_steps
573 self.next_step = next_step
574 self.max_workers = max_workers or min(32, len(parallel_steps) + 4)
576 # Validate that all steps have names
577 # 全ステップに名前があることを検証
578 for step in parallel_steps:
579 if not hasattr(step, 'name') or not step.name:
580 raise ValueError(f"All parallel steps must have valid names: {step}")
582 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
583 """
584 Execute parallel steps
585 並列ステップを実行
587 Args:
588 user_input: User input (passed to all parallel steps) / ユーザー入力(全並列ステップに渡される)
589 ctx: Current context / 現在のコンテキスト
591 Returns:
592 Context: Updated context with merged results / マージされた結果を持つ更新済みコンテキスト
593 """
594 ctx.update_step_info(self.name)
596 # Create separate contexts for each parallel step
597 # 各並列ステップ用に別々のコンテキストを作成
598 parallel_contexts = []
599 for step in self.parallel_steps:
600 # Clone context for each parallel execution
601 # 各並列実行用にコンテキストをクローン
602 step_ctx = self._clone_context_for_parallel(ctx, step.name)
603 parallel_contexts.append((step, step_ctx))
605 # Execute all steps in parallel
606 # 全ステップを並列実行
607 async def run_parallel_step(step_and_ctx):
608 step, step_ctx = step_and_ctx
609 try:
610 result_ctx = await step.run(user_input, step_ctx)
611 return step.name, result_ctx, None
612 except Exception as e:
613 return step.name, step_ctx, e
615 # Use asyncio.gather for parallel execution
616 # 並列実行にasyncio.gatherを使用
617 results = await asyncio.gather(
618 *[run_parallel_step(sc) for sc in parallel_contexts],
619 return_exceptions=True
620 )
622 # Merge results back into main context
623 # 結果をメインコンテキストにマージ
624 errors = []
625 for result in results:
626 if isinstance(result, Exception):
627 errors.append(result)
628 continue
630 step_name, result_ctx, error = result
631 if error:
632 errors.append(f"Step {step_name}: {error}")
633 continue
635 # Merge parallel step results
636 # 並列ステップ結果をマージ
637 self._merge_parallel_result(ctx, step_name, result_ctx)
639 # Handle errors if any
640 # エラーがあれば処理
641 if errors:
642 error_msg = f"Parallel execution errors: {'; '.join(map(str, errors))}"
643 ctx.add_system_message(error_msg)
644 raise RuntimeError(error_msg)
646 # Set next step or finish
647 # 次ステップを設定または終了
648 if self.next_step:
649 ctx.goto(self.next_step)
650 else:
651 ctx.finish()
653 return ctx
655 def _clone_context_for_parallel(self, ctx: Context, step_name: str) -> Context:
656 """
657 Clone context for parallel execution
658 並列実行用にコンテキストをクローン
660 Args:
661 ctx: Original context / 元のコンテキスト
662 step_name: Name of the step / ステップ名
664 Returns:
665 Context: Cloned context / クローンされたコンテキスト
666 """
667 # Create new context with shared state
668 # 共有状態を持つ新しいコンテキストを作成
669 cloned_ctx = Context()
671 # Copy essential state
672 # 必須状態をコピー
673 cloned_ctx.shared_state = ctx.shared_state.copy()
674 cloned_ctx.messages = ctx.messages.copy()
675 cloned_ctx.last_user_input = ctx.last_user_input
676 cloned_ctx.trace_id = ctx.trace_id
677 cloned_ctx.span_history = ctx.span_history.copy()
679 # Set step-specific information
680 # ステップ固有情報を設定
681 cloned_ctx.current_step = step_name
683 return cloned_ctx
685 def _merge_parallel_result(self, main_ctx: Context, step_name: str, result_ctx: Context) -> None:
686 """
687 Merge parallel step result into main context
688 並列ステップ結果をメインコンテキストにマージ
690 Args:
691 main_ctx: Main context / メインコンテキスト
692 step_name: Name of the completed step / 完了したステップ名
693 result_ctx: Result context from parallel step / 並列ステップからの結果コンテキスト
694 """
695 # Merge shared state with step-specific keys
696 # ステップ固有キーで共有状態をマージ
697 for key, value in result_ctx.shared_state.items():
698 if key not in main_ctx.shared_state:
699 main_ctx.shared_state[key] = value
700 else:
701 # Handle conflicts by prefixing with step name
702 # ステップ名をプレフィックスとして衝突を処理
703 prefixed_key = f"{step_name}_{key}"
704 main_ctx.shared_state[prefixed_key] = value
706 # Merge conversation history
707 # 会話履歴をマージ
708 main_ctx.messages.extend(result_ctx.messages)
710 # Update execution path
711 # 実行パスを更新
712 main_ctx.span_history.extend(result_ctx.span_history)
714 # Store step-specific results metadata (avoid overwriting user data)
715 # ステップ固有結果メタデータを保存(ユーザーデータの上書きを避ける)
716 main_ctx.shared_state[f"__{step_name}_metadata__"] = {
717 "status": "completed",
718 "output": result_ctx.shared_state,
719 "messages": result_ctx.messages
720 }