Coverage for src/refinire/agents/flow/step.py: 76%

214 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-15 18:51 +0900

1from __future__ import annotations 

2 

3"""Step — Step interface and basic implementations for Flow workflows. 

4 

5Stepはフローワークフロー用のステップインターフェースと基本実装を提供します。 

6UserInputStep、ConditionStep、ForkStep、JoinStepなどの基本的なステップを含みます。 

7""" 

8 

9import asyncio 

10from abc import ABC, abstractmethod 

11from typing import Any, Callable, Dict, List, Optional, Union, Awaitable 

12from concurrent.futures import ThreadPoolExecutor 

13import threading 

14 

15from .context import Context 

16 

17 

18class Step(ABC): 

19 """ 

20 Abstract base class for workflow steps 

21 ワークフローステップの抽象基底クラス 

22  

23 All step implementations must provide: 

24 全てのステップ実装は以下を提供する必要があります: 

25 - name: Step identifier for DSL reference / DSL参照用ステップ識別子 

26 - run: Async execution method / 非同期実行メソッド 

27 """ 

28 

29 def __init__(self, name: str): 

30 """ 

31 Initialize step with name 

32 名前でステップを初期化 

33  

34 Args: 

35 name: Step name / ステップ名 

36 """ 

37 self.name = name 

38 

39 @abstractmethod 

40 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

41 """ 

42 Execute step and return updated context 

43 ステップを実行し、更新されたコンテキストを返す 

44  

45 Args: 

46 user_input: User input if any / ユーザー入力(あれば) 

47 ctx: Current context / 現在のコンテキスト 

48  

49 Returns: 

50 Context: Updated context with next_label set / next_labelが設定された更新済みコンテキスト 

51 """ 

52 pass 

53 

54 def __str__(self) -> str: 

55 return f"{self.__class__.__name__}({self.name})" 

56 

57 def __repr__(self) -> str: 

58 return self.__str__() 

59 

60 

61class UserInputStep(Step): 

62 """ 

63 Step that waits for user input 

64 ユーザー入力を待機するステップ 

65  

66 This step displays a prompt and waits for user response. 

67 このステップはプロンプトを表示し、ユーザー応答を待機します。 

68 It sets the context to waiting state and returns without advancing. 

69 コンテキストを待機状態に設定し、進行せずに返します。 

70 """ 

71 

72 def __init__(self, name: str, prompt: str, next_step: Optional[str] = None): 

73 """ 

74 Initialize user input step 

75 ユーザー入力ステップを初期化 

76  

77 Args: 

78 name: Step name / ステップ名 

79 prompt: Prompt to display to user / ユーザーに表示するプロンプト 

80 next_step: Next step after input (optional) / 入力後の次ステップ(オプション) 

81 """ 

82 super().__init__(name) 

83 self.prompt = prompt 

84 self.next_step = next_step 

85 

86 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

87 """ 

88 Execute user input step 

89 ユーザー入力ステップを実行 

90  

91 Args: 

92 user_input: User input if available / 利用可能なユーザー入力 

93 ctx: Current context / 現在のコンテキスト 

94  

95 Returns: 

96 Context: Updated context / 更新済みコンテキスト 

97 """ 

98 ctx.update_step_info(self.name) 

99 

100 # If user input is provided, process it 

101 # ユーザー入力が提供されている場合、処理する 

102 if user_input is not None: 

103 ctx.provide_user_input(user_input) 

104 if self.next_step: 

105 ctx.goto(self.next_step) 

106 else: 

107 ctx.finish() 

108 # Note: If next_step is None, flow will end 

109 # 注:next_stepがNoneの場合、フローは終了 

110 else: 

111 # Set waiting state for user input 

112 # ユーザー入力の待機状態を設定 

113 ctx.set_waiting_for_user_input(self.prompt) 

114 

115 return ctx 

116 

117 

118class ConditionStep(Step): 

119 """ 

120 Step that performs conditional routing 

121 条件付きルーティングを実行するステップ 

122  

123 This step evaluates a condition and routes to different steps based on the result. 

124 このステップは条件を評価し、結果に基づいて異なるステップにルーティングします。 

125 """ 

126 

127 def __init__( 

128 self, 

129 name: str, 

130 condition: Callable[[Context], Union[bool, Awaitable[bool]]], 

131 if_true: str, 

132 if_false: str 

133 ): 

134 """ 

135 Initialize condition step 

136 条件ステップを初期化 

137  

138 Args: 

139 name: Step name / ステップ名 

140 condition: Condition function / 条件関数 

141 if_true: Step to go if condition is True / 条件がTrueの場合のステップ 

142 if_false: Step to go if condition is False / 条件がFalseの場合のステップ 

143 """ 

144 super().__init__(name) 

145 self.condition = condition 

146 self.if_true = if_true 

147 self.if_false = if_false 

148 

149 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

150 """ 

151 Execute condition step 

152 条件ステップを実行 

153  

154 Args: 

155 user_input: User input (not used) / ユーザー入力(使用されない) 

156 ctx: Current context / 現在のコンテキスト 

157  

158 Returns: 

159 Context: Updated context with routing / ルーティング付き更新済みコンテキスト 

160 """ 

161 ctx.update_step_info(self.name) 

162 

163 # Evaluate condition (may be async) 

164 # 条件を評価(非同期の可能性あり) 

165 try: 

166 result = self.condition(ctx) 

167 if asyncio.iscoroutine(result): 

168 result = await result 

169 except Exception as e: 

170 # On error, go to false branch 

171 # エラー時はfalseブランチに進む 

172 ctx.add_system_message(f"Condition evaluation error: {e}") 

173 result = False 

174 

175 # Route based on condition result 

176 # 条件結果に基づいてルーティング 

177 next_step = self.if_true if result else self.if_false 

178 ctx.goto(next_step) 

179 

180 return ctx 

181 

182 

183class FunctionStep(Step): 

184 """ 

185 Step that executes a custom function 

186 カスタム関数を実行するステップ 

187  

188 This step allows executing arbitrary code within the workflow. 

189 このステップはワークフロー内で任意のコードを実行できます。 

190 """ 

191 

192 def __init__( 

193 self, 

194 name: str, 

195 function: Callable[[Optional[str], Context], Union[Context, Awaitable[Context]]], 

196 next_step: Optional[str] = None 

197 ): 

198 """ 

199 Initialize function step 

200 関数ステップを初期化 

201  

202 Args: 

203 name: Step name / ステップ名 

204 function: Function to execute / 実行する関数 

205 next_step: Next step after execution / 実行後の次ステップ 

206 """ 

207 super().__init__(name) 

208 self.function = function 

209 self.next_step = next_step 

210 

211 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

212 """ 

213 Execute function step 

214 関数ステップを実行 

215  

216 Args: 

217 user_input: User input / ユーザー入力 

218 ctx: Current context / 現在のコンテキスト 

219  

220 Returns: 

221 Context: Updated context / 更新済みコンテキスト 

222 """ 

223 ctx.update_step_info(self.name) 

224 

225 try: 

226 # Execute the function (may be async) 

227 # 関数を実行(非同期の可能性あり) 

228 result = self.function(user_input, ctx) 

229 if asyncio.iscoroutine(result): 

230 ctx = await result 

231 else: 

232 ctx = result 

233 except Exception as e: 

234 ctx.add_system_message(f"Function execution error in {self.name}: {e}") 

235 

236 # Set next step if specified, otherwise finish the flow 

237 # 指定されている場合は次ステップを設定、そうでなければフローを終了 

238 if self.next_step: 

239 ctx.goto(self.next_step) 

240 else: 

241 ctx.finish() 

242 

243 return ctx 

244 

245 

246class ForkStep(Step): 

247 """ 

248 Step that executes multiple branches in parallel 

249 複数のブランチを並列実行するステップ 

250  

251 This step starts multiple sub-flows concurrently and collects their results. 

252 このステップは複数のサブフローを同時に開始し、結果を収集します。 

253 """ 

254 

255 def __init__(self, name: str, branches: List[str], join_step: str): 

256 """ 

257 Initialize fork step 

258 フォークステップを初期化 

259  

260 Args: 

261 name: Step name / ステップ名 

262 branches: List of branch step names to execute in parallel / 並列実行するブランチステップ名のリスト 

263 join_step: Step to join results / 結果を結合するステップ 

264 """ 

265 super().__init__(name) 

266 self.branches = branches 

267 self.join_step = join_step 

268 

269 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

270 """ 

271 Execute fork step 

272 フォークステップを実行 

273  

274 Args: 

275 user_input: User input / ユーザー入力 

276 ctx: Current context / 現在のコンテキスト 

277  

278 Returns: 

279 Context: Updated context / 更新済みコンテキスト 

280 """ 

281 ctx.update_step_info(self.name) 

282 

283 # Store branch information for join step 

284 # ジョインステップ用にブランチ情報を保存 

285 ctx.shared_state[f"{self.name}_branches"] = self.branches 

286 ctx.shared_state[f"{self.name}_started"] = True 

287 

288 # For now, just route to the join step 

289 # 現在のところ、ジョインステップにルーティングするだけ 

290 # In a full implementation, this would start parallel execution 

291 # 完全な実装では、これは並列実行を開始する 

292 ctx.goto(self.join_step) 

293 

294 return ctx 

295 

296 

297class JoinStep(Step): 

298 """ 

299 Step that joins results from parallel branches 

300 並列ブランチからの結果を結合するステップ 

301  

302 This step waits for parallel branches to complete and merges their results. 

303 このステップは並列ブランチの完了を待機し、結果をマージします。 

304 """ 

305 

306 def __init__(self, name: str, fork_step: str, join_type: str = "all", next_step: Optional[str] = None): 

307 """ 

308 Initialize join step 

309 ジョインステップを初期化 

310  

311 Args: 

312 name: Step name / ステップ名 

313 fork_step: Associated fork step name / 関連するフォークステップ名 

314 join_type: Join type ("all" or "any") / ジョインタイプ("all"または"any") 

315 next_step: Next step after join / ジョイン後の次ステップ 

316 """ 

317 super().__init__(name) 

318 self.fork_step = fork_step 

319 self.join_type = join_type 

320 self.next_step = next_step 

321 

322 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

323 """ 

324 Execute join step 

325 ジョインステップを実行 

326  

327 Args: 

328 user_input: User input / ユーザー入力 

329 ctx: Current context / 現在のコンテキスト 

330  

331 Returns: 

332 Context: Updated context / 更新済みコンテキスト 

333 """ 

334 ctx.update_step_info(self.name) 

335 

336 # Get branch information from shared state 

337 # 共有状態からブランチ情報を取得 

338 branches = ctx.shared_state.get(f"{self.fork_step}_branches", []) 

339 

340 # For now, just mark as completed 

341 # 現在のところ、完了としてマークするだけ 

342 # In a full implementation, this would wait for and merge branch results 

343 # 完全な実装では、これはブランチ結果を待機してマージする 

344 ctx.add_system_message(f"Joined {len(branches)} branches using {self.join_type} strategy") 

345 

346 # Set next step if specified, otherwise finish the flow 

347 # 指定されている場合は次ステップを設定、そうでなければフローを終了 

348 if self.next_step: 

349 ctx.goto(self.next_step) 

350 else: 

351 ctx.finish() 

352 

353 return ctx 

354 

355 

356class AgentPipelineStep(Step): 

357 """ 

358 Step that wraps AgentPipeline for use in Flow 

359 FlowでAgentPipelineを使用するためのラッパーステップ 

360  

361 This step allows using existing AgentPipeline instances as flow steps. 

362 このステップは既存のAgentPipelineインスタンスをフローステップとして使用できます。 

363 """ 

364 

365 def __init__(self, name: str, pipeline: Any, next_step: Optional[str] = None): 

366 """ 

367 Initialize agent pipeline step 

368 エージェントパイプラインステップを初期化 

369  

370 Args: 

371 name: Step name / ステップ名 

372 pipeline: AgentPipeline instance / AgentPipelineインスタンス 

373 next_step: Next step after pipeline execution / パイプライン実行後の次ステップ 

374 """ 

375 super().__init__(name) 

376 self.pipeline = pipeline 

377 self.next_step = next_step 

378 

379 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

380 """ 

381 Execute agent pipeline step 

382 エージェントパイプラインステップを実行 

383  

384 Args: 

385 user_input: User input / ユーザー入力 

386 ctx: Current context / 現在のコンテキスト 

387  

388 Returns: 

389 Context: Updated context / 更新済みコンテキスト 

390 """ 

391 ctx.update_step_info(self.name) 

392 

393 try: 

394 # Use the last user input if available 

395 # 利用可能な場合は最後のユーザー入力を使用 

396 input_text = user_input or ctx.last_user_input or "" 

397 

398 # Execute pipeline in thread pool to handle sync methods 

399 # 同期メソッドを処理するためにスレッドプールでパイプラインを実行 

400 loop = asyncio.get_event_loop() 

401 with ThreadPoolExecutor() as executor: 

402 future = loop.run_in_executor(executor, self.pipeline.run, input_text) 

403 result = await future 

404 

405 # Store result in context 

406 # 結果をコンテキストに保存 

407 if result is not None: 

408 ctx.prev_outputs[self.name] = result 

409 ctx.add_assistant_message(str(result)) 

410 

411 except Exception as e: 

412 ctx.add_system_message(f"Pipeline execution error in {self.name}: {e}") 

413 ctx.prev_outputs[self.name] = None 

414 

415 # Set next step if specified, otherwise finish the flow 

416 # 指定されている場合は次ステップを設定、そうでなければフローを終了 

417 if self.next_step: 

418 ctx.goto(self.next_step) 

419 else: 

420 ctx.finish() 

421 

422 return ctx 

423 

424 

425class DebugStep(Step): 

426 """ 

427 Step for debugging and logging 

428 デバッグとログ用ステップ 

429  

430 This step prints or logs context information for debugging purposes. 

431 このステップはデバッグ目的でコンテキスト情報を印刷またはログ出力します。 

432 """ 

433 

434 def __init__(self, name: str, message: str = "", print_context: bool = False, next_step: Optional[str] = None): 

435 """ 

436 Initialize debug step 

437 デバッグステップを初期化 

438  

439 Args: 

440 name: Step name / ステップ名 

441 message: Debug message / デバッグメッセージ 

442 print_context: Whether to print full context / 完全なコンテキストを印刷するか 

443 next_step: Next step / 次ステップ 

444 """ 

445 super().__init__(name) 

446 self.message = message 

447 self.print_context = print_context 

448 self.next_step = next_step 

449 

450 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

451 """ 

452 Execute debug step 

453 デバッグステップを実行 

454  

455 Args: 

456 user_input: User input / ユーザー入力 

457 ctx: Current context / 現在のコンテキスト 

458  

459 Returns: 

460 Context: Updated context / 更新済みコンテキスト 

461 """ 

462 ctx.update_step_info(self.name) 

463 

464 # Print debug information 

465 # デバッグ情報を印刷 

466 print(f"🐛 DEBUG [{self.name}]: {self.message}") 

467 if user_input: 

468 print(f" User Input: {user_input}") 

469 print(f" Step Count: {ctx.step_count}") 

470 print(f" Next Label: {ctx.next_label}") 

471 

472 if self.print_context: 

473 print(f" Context: {ctx.model_dump()}") 

474 

475 # Add debug message to system messages 

476 # デバッグメッセージをシステムメッセージに追加 

477 ctx.add_system_message(f"DEBUG {self.name}: {self.message}") 

478 

479 # Set next step if specified, otherwise finish the flow 

480 # 指定されている場合は次ステップを設定、そうでなければフローを終了 

481 if self.next_step: 

482 ctx.goto(self.next_step) 

483 else: 

484 ctx.finish() 

485 

486 return ctx 

487 

488 

489# Utility functions for creating common step patterns 

490# 一般的なステップパターンを作成するユーティリティ関数 

491 

492def create_simple_condition(field_path: str, expected_value: Any) -> Callable[[Context], bool]: 

493 """ 

494 Create a simple condition function that checks a field value 

495 フィールド値をチェックする簡単な条件関数を作成 

496  

497 Args: 

498 field_path: Dot-separated path to field (e.g., "shared_state.status") / フィールドへのドット区切りパス 

499 expected_value: Expected value / 期待値 

500  

501 Returns: 

502 Callable[[Context], bool]: Condition function / 条件関数 

503 """ 

504 def condition(ctx: Context) -> bool: 

505 try: 

506 # Navigate to the field using dot notation 

507 # ドット記法を使用してフィールドに移動 

508 obj = ctx 

509 for part in field_path.split('.'): 

510 if hasattr(obj, part): 

511 obj = getattr(obj, part) 

512 elif isinstance(obj, dict) and part in obj: 

513 obj = obj[part] 

514 else: 

515 return False 

516 return obj == expected_value 

517 except Exception: 

518 return False 

519 

520 return condition 

521 

522 

523def create_lambda_step(name: str, func: Callable[[Context], Any], next_step: Optional[str] = None) -> FunctionStep: 

524 """ 

525 Create a simple function step from a lambda 

526 ラムダから簡単な関数ステップを作成 

527  

528 Args: 

529 name: Step name / ステップ名 

530 func: Function to execute / 実行する関数 

531 next_step: Next step / 次ステップ 

532  

533 Returns: 

534 FunctionStep: Function step / 関数ステップ 

535 """ 

536 def wrapper(user_input: Optional[str], ctx: Context) -> Context: 

537 func(ctx) 

538 return ctx 

539 

540 return FunctionStep(name, wrapper, next_step) 

541 

542 

543class ParallelStep(Step): 

544 """ 

545 Step that executes multiple steps in parallel 

546 複数のステップを並列実行するステップ 

547  

548 This step automatically manages parallel execution of child steps. 

549 このステップは子ステップの並列実行を自動管理します。 

550 It waits for all parallel steps to complete before proceeding. 

551 全ての並列ステップが完了するまで待機してから進行します。 

552 """ 

553 

554 def __init__( 

555 self, 

556 name: str, 

557 parallel_steps: List[Step], 

558 next_step: Optional[str] = None, 

559 max_workers: Optional[int] = None 

560 ): 

561 """ 

562 Initialize parallel step 

563 並列ステップを初期化 

564  

565 Args: 

566 name: Step name / ステップ名 

567 parallel_steps: List of steps to execute in parallel / 並列実行するステップのリスト 

568 next_step: Next step after all parallel steps complete / 全並列ステップ完了後の次ステップ 

569 max_workers: Maximum number of concurrent workers / 最大同時ワーカー数 

570 """ 

571 super().__init__(name) 

572 self.parallel_steps = parallel_steps 

573 self.next_step = next_step 

574 self.max_workers = max_workers or min(32, len(parallel_steps) + 4) 

575 

576 # Validate that all steps have names 

577 # 全ステップに名前があることを検証 

578 for step in parallel_steps: 

579 if not hasattr(step, 'name') or not step.name: 

580 raise ValueError(f"All parallel steps must have valid names: {step}") 

581 

582 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

583 """ 

584 Execute parallel steps 

585 並列ステップを実行 

586  

587 Args: 

588 user_input: User input (passed to all parallel steps) / ユーザー入力(全並列ステップに渡される) 

589 ctx: Current context / 現在のコンテキスト 

590  

591 Returns: 

592 Context: Updated context with merged results / マージされた結果を持つ更新済みコンテキスト 

593 """ 

594 ctx.update_step_info(self.name) 

595 

596 # Create separate contexts for each parallel step 

597 # 各並列ステップ用に別々のコンテキストを作成 

598 parallel_contexts = [] 

599 for step in self.parallel_steps: 

600 # Clone context for each parallel execution 

601 # 各並列実行用にコンテキストをクローン 

602 step_ctx = self._clone_context_for_parallel(ctx, step.name) 

603 parallel_contexts.append((step, step_ctx)) 

604 

605 # Execute all steps in parallel 

606 # 全ステップを並列実行 

607 async def run_parallel_step(step_and_ctx): 

608 step, step_ctx = step_and_ctx 

609 try: 

610 result_ctx = await step.run(user_input, step_ctx) 

611 return step.name, result_ctx, None 

612 except Exception as e: 

613 return step.name, step_ctx, e 

614 

615 # Use asyncio.gather for parallel execution 

616 # 並列実行にasyncio.gatherを使用 

617 results = await asyncio.gather( 

618 *[run_parallel_step(sc) for sc in parallel_contexts], 

619 return_exceptions=True 

620 ) 

621 

622 # Merge results back into main context 

623 # 結果をメインコンテキストにマージ 

624 errors = [] 

625 for result in results: 

626 if isinstance(result, Exception): 

627 errors.append(result) 

628 continue 

629 

630 step_name, result_ctx, error = result 

631 if error: 

632 errors.append(f"Step {step_name}: {error}") 

633 continue 

634 

635 # Merge parallel step results 

636 # 並列ステップ結果をマージ 

637 self._merge_parallel_result(ctx, step_name, result_ctx) 

638 

639 # Handle errors if any 

640 # エラーがあれば処理 

641 if errors: 

642 error_msg = f"Parallel execution errors: {'; '.join(map(str, errors))}" 

643 ctx.add_system_message(error_msg) 

644 raise RuntimeError(error_msg) 

645 

646 # Set next step or finish 

647 # 次ステップを設定または終了 

648 if self.next_step: 

649 ctx.goto(self.next_step) 

650 else: 

651 ctx.finish() 

652 

653 return ctx 

654 

655 def _clone_context_for_parallel(self, ctx: Context, step_name: str) -> Context: 

656 """ 

657 Clone context for parallel execution 

658 並列実行用にコンテキストをクローン 

659  

660 Args: 

661 ctx: Original context / 元のコンテキスト 

662 step_name: Name of the step / ステップ名 

663  

664 Returns: 

665 Context: Cloned context / クローンされたコンテキスト 

666 """ 

667 # Create new context with shared state 

668 # 共有状態を持つ新しいコンテキストを作成 

669 cloned_ctx = Context() 

670 

671 # Copy essential state 

672 # 必須状態をコピー 

673 cloned_ctx.shared_state = ctx.shared_state.copy() 

674 cloned_ctx.messages = ctx.messages.copy() 

675 cloned_ctx.last_user_input = ctx.last_user_input 

676 cloned_ctx.trace_id = ctx.trace_id 

677 cloned_ctx.span_history = ctx.span_history.copy() 

678 

679 # Set step-specific information 

680 # ステップ固有情報を設定 

681 cloned_ctx.current_step = step_name 

682 

683 return cloned_ctx 

684 

685 def _merge_parallel_result(self, main_ctx: Context, step_name: str, result_ctx: Context) -> None: 

686 """ 

687 Merge parallel step result into main context 

688 並列ステップ結果をメインコンテキストにマージ 

689  

690 Args: 

691 main_ctx: Main context / メインコンテキスト 

692 step_name: Name of the completed step / 完了したステップ名 

693 result_ctx: Result context from parallel step / 並列ステップからの結果コンテキスト 

694 """ 

695 # Merge shared state with step-specific keys 

696 # ステップ固有キーで共有状態をマージ 

697 for key, value in result_ctx.shared_state.items(): 

698 if key not in main_ctx.shared_state: 

699 main_ctx.shared_state[key] = value 

700 else: 

701 # Handle conflicts by prefixing with step name 

702 # ステップ名をプレフィックスとして衝突を処理 

703 prefixed_key = f"{step_name}_{key}" 

704 main_ctx.shared_state[prefixed_key] = value 

705 

706 # Merge conversation history 

707 # 会話履歴をマージ 

708 main_ctx.messages.extend(result_ctx.messages) 

709 

710 # Update execution path 

711 # 実行パスを更新 

712 main_ctx.span_history.extend(result_ctx.span_history) 

713 

714 # Store step-specific results metadata (avoid overwriting user data) 

715 # ステップ固有結果メタデータを保存(ユーザーデータの上書きを避ける) 

716 main_ctx.shared_state[f"__{step_name}_metadata__"] = { 

717 "status": "completed", 

718 "output": result_ctx.shared_state, 

719 "messages": result_ctx.messages 

720 }