Coverage for src/refinire/agents/flow/step.py: 83%

191 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 15:27 +0900

1from __future__ import annotations 

2 

3"""Step — Step interface and basic implementations for Flow workflows. 

4 

5Stepはフローワークフロー用のステップインターフェースと基本実装を提供します。 

6UserInputStep、ConditionStep、ForkStep、JoinStepなどの基本的なステップを含みます。 

7""" 

8 

9import asyncio 

10from abc import ABC, abstractmethod 

11from typing import Any, Callable, Dict, List, Optional, Union, Awaitable 

12from concurrent.futures import ThreadPoolExecutor 

13import threading 

14 

15from .context import Context 

16 

17 

18class Step(ABC): 

19 """ 

20 Abstract base class for workflow steps 

21 ワークフローステップの抽象基底クラス 

22  

23 All step implementations must provide: 

24 全てのステップ実装は以下を提供する必要があります: 

25 - name: Step identifier for DSL reference / DSL参照用ステップ識別子 

26 - run: Async execution method / 非同期実行メソッド 

27 """ 

28 

29 def __init__(self, name: str): 

30 """ 

31 Initialize step with name 

32 名前でステップを初期化 

33  

34 Args: 

35 name: Step name / ステップ名 

36 """ 

37 self.name = name 

38 

39 @abstractmethod 

40 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

41 """ 

42 Execute step and return updated context 

43 ステップを実行し、更新されたコンテキストを返す 

44  

45 Args: 

46 user_input: User input if any / ユーザー入力(あれば) 

47 ctx: Current context / 現在のコンテキスト 

48  

49 Returns: 

50 Context: Updated context with next_label set / next_labelが設定された更新済みコンテキスト 

51 """ 

52 pass 

53 

54 def __str__(self) -> str: 

55 return f"{self.__class__.__name__}({self.name})" 

56 

57 def __repr__(self) -> str: 

58 return self.__str__() 

59 

60 

61class UserInputStep(Step): 

62 """ 

63 Step that waits for user input 

64 ユーザー入力を待機するステップ 

65  

66 This step displays a prompt and waits for user response. 

67 このステップはプロンプトを表示し、ユーザー応答を待機します。 

68 It sets the context to waiting state and returns without advancing. 

69 コンテキストを待機状態に設定し、進行せずに返します。 

70 """ 

71 

72 def __init__(self, name: str, prompt: str, next_step: Optional[str] = None): 

73 """ 

74 Initialize user input step 

75 ユーザー入力ステップを初期化 

76  

77 Args: 

78 name: Step name / ステップ名 

79 prompt: Prompt to display to user / ユーザーに表示するプロンプト 

80 next_step: Next step after input (optional) / 入力後の次ステップ(オプション) 

81 """ 

82 super().__init__(name) 

83 self.prompt = prompt 

84 self.next_step = next_step 

85 

86 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

87 """ 

88 Execute user input step 

89 ユーザー入力ステップを実行 

90  

91 Args: 

92 user_input: User input if available / 利用可能なユーザー入力 

93 ctx: Current context / 現在のコンテキスト 

94  

95 Returns: 

96 Context: Updated context / 更新済みコンテキスト 

97 """ 

98 ctx.update_step_info(self.name) 

99 

100 # If user input is provided, process it 

101 # ユーザー入力が提供されている場合、処理する 

102 if user_input is not None: 

103 ctx.provide_user_input(user_input) 

104 if self.next_step: 

105 ctx.goto(self.next_step) 

106 else: 

107 ctx.finish() 

108 # Note: If next_step is None, flow will end 

109 # 注:next_stepがNoneの場合、フローは終了 

110 else: 

111 # Set waiting state for user input 

112 # ユーザー入力の待機状態を設定 

113 ctx.set_waiting_for_user_input(self.prompt) 

114 

115 return ctx 

116 

117 

118class ConditionStep(Step): 

119 """ 

120 Step that performs conditional routing 

121 条件付きルーティングを実行するステップ 

122  

123 This step evaluates a condition and routes to different steps based on the result. 

124 このステップは条件を評価し、結果に基づいて異なるステップにルーティングします。 

125 """ 

126 

127 def __init__( 

128 self, 

129 name: str, 

130 condition: Callable[[Context], Union[bool, Awaitable[bool]]], 

131 if_true: str, 

132 if_false: str 

133 ): 

134 """ 

135 Initialize condition step 

136 条件ステップを初期化 

137  

138 Args: 

139 name: Step name / ステップ名 

140 condition: Condition function / 条件関数 

141 if_true: Step to go if condition is True / 条件がTrueの場合のステップ 

142 if_false: Step to go if condition is False / 条件がFalseの場合のステップ 

143 """ 

144 super().__init__(name) 

145 self.condition = condition 

146 self.if_true = if_true 

147 self.if_false = if_false 

148 

149 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

150 """ 

151 Execute condition step 

152 条件ステップを実行 

153  

154 Args: 

155 user_input: User input (not used) / ユーザー入力(使用されない) 

156 ctx: Current context / 現在のコンテキスト 

157  

158 Returns: 

159 Context: Updated context with routing / ルーティング付き更新済みコンテキスト 

160 """ 

161 ctx.update_step_info(self.name) 

162 

163 # Evaluate condition (may be async) 

164 # 条件を評価(非同期の可能性あり) 

165 try: 

166 result = self.condition(ctx) 

167 if asyncio.iscoroutine(result): 

168 result = await result 

169 except Exception as e: 

170 # On error, go to false branch 

171 # エラー時はfalseブランチに進む 

172 ctx.add_system_message(f"Condition evaluation error: {e}") 

173 result = False 

174 

175 # Route based on condition result 

176 # 条件結果に基づいてルーティング 

177 next_step = self.if_true if result else self.if_false 

178 ctx.goto(next_step) 

179 

180 return ctx 

181 

182 

183class FunctionStep(Step): 

184 """ 

185 Step that executes a custom function 

186 カスタム関数を実行するステップ 

187  

188 This step allows executing arbitrary code within the workflow. 

189 このステップはワークフロー内で任意のコードを実行できます。 

190 """ 

191 

192 def __init__( 

193 self, 

194 name: str, 

195 function: Callable[[Optional[str], Context], Union[Context, Awaitable[Context]]], 

196 next_step: Optional[str] = None 

197 ): 

198 """ 

199 Initialize function step 

200 関数ステップを初期化 

201  

202 Args: 

203 name: Step name / ステップ名 

204 function: Function to execute / 実行する関数 

205 next_step: Next step after execution / 実行後の次ステップ 

206 """ 

207 super().__init__(name) 

208 self.function = function 

209 self.next_step = next_step 

210 

211 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

212 """ 

213 Execute function step 

214 関数ステップを実行 

215  

216 Args: 

217 user_input: User input / ユーザー入力 

218 ctx: Current context / 現在のコンテキスト 

219  

220 Returns: 

221 Context: Updated context / 更新済みコンテキスト 

222 """ 

223 ctx.update_step_info(self.name) 

224 

225 try: 

226 # Execute the function (may be async) 

227 # 関数を実行(非同期の可能性あり) 

228 result = self.function(user_input, ctx) 

229 if asyncio.iscoroutine(result): 

230 ctx = await result 

231 else: 

232 ctx = result 

233 except Exception as e: 

234 ctx.add_system_message(f"Function execution error in {self.name}: {e}") 

235 

236 # Set next step if specified, otherwise finish the flow 

237 # 指定されている場合は次ステップを設定、そうでなければフローを終了 

238 if self.next_step: 

239 ctx.goto(self.next_step) 

240 else: 

241 ctx.finish() 

242 

243 return ctx 

244 

245 

246class ForkStep(Step): 

247 """ 

248 Step that executes multiple branches in parallel 

249 複数のブランチを並列実行するステップ 

250  

251 This step starts multiple sub-flows concurrently and collects their results. 

252 このステップは複数のサブフローを同時に開始し、結果を収集します。 

253 """ 

254 

255 def __init__(self, name: str, branches: List[str], join_step: str): 

256 """ 

257 Initialize fork step 

258 フォークステップを初期化 

259  

260 Args: 

261 name: Step name / ステップ名 

262 branches: List of branch step names to execute in parallel / 並列実行するブランチステップ名のリスト 

263 join_step: Step to join results / 結果を結合するステップ 

264 """ 

265 super().__init__(name) 

266 self.branches = branches 

267 self.join_step = join_step 

268 

269 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

270 """ 

271 Execute fork step 

272 フォークステップを実行 

273  

274 Args: 

275 user_input: User input / ユーザー入力 

276 ctx: Current context / 現在のコンテキスト 

277  

278 Returns: 

279 Context: Updated context / 更新済みコンテキスト 

280 """ 

281 ctx.update_step_info(self.name) 

282 

283 # Store branch information for join step 

284 # ジョインステップ用にブランチ情報を保存 

285 ctx.shared_state[f"{self.name}_branches"] = self.branches 

286 ctx.shared_state[f"{self.name}_started"] = True 

287 

288 # For now, just route to the join step 

289 # 現在のところ、ジョインステップにルーティングするだけ 

290 # In a full implementation, this would start parallel execution 

291 # 完全な実装では、これは並列実行を開始する 

292 ctx.goto(self.join_step) 

293 

294 return ctx 

295 

296 

297class JoinStep(Step): 

298 """ 

299 Step that joins results from parallel branches 

300 並列ブランチからの結果を結合するステップ 

301  

302 This step waits for parallel branches to complete and merges their results. 

303 このステップは並列ブランチの完了を待機し、結果をマージします。 

304 """ 

305 

306 def __init__(self, name: str, fork_step: str, join_type: str = "all", next_step: Optional[str] = None): 

307 """ 

308 Initialize join step 

309 ジョインステップを初期化 

310  

311 Args: 

312 name: Step name / ステップ名 

313 fork_step: Associated fork step name / 関連するフォークステップ名 

314 join_type: Join type ("all" or "any") / ジョインタイプ("all"または"any") 

315 next_step: Next step after join / ジョイン後の次ステップ 

316 """ 

317 super().__init__(name) 

318 self.fork_step = fork_step 

319 self.join_type = join_type 

320 self.next_step = next_step 

321 

322 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

323 """ 

324 Execute join step 

325 ジョインステップを実行 

326  

327 Args: 

328 user_input: User input / ユーザー入力 

329 ctx: Current context / 現在のコンテキスト 

330  

331 Returns: 

332 Context: Updated context / 更新済みコンテキスト 

333 """ 

334 ctx.update_step_info(self.name) 

335 

336 # Get branch information from shared state 

337 # 共有状態からブランチ情報を取得 

338 branches = ctx.shared_state.get(f"{self.fork_step}_branches", []) 

339 

340 # For now, just mark as completed 

341 # 現在のところ、完了としてマークするだけ 

342 # In a full implementation, this would wait for and merge branch results 

343 # 完全な実装では、これはブランチ結果を待機してマージする 

344 ctx.add_system_message(f"Joined {len(branches)} branches using {self.join_type} strategy") 

345 

346 # Set next step if specified, otherwise finish the flow 

347 # 指定されている場合は次ステップを設定、そうでなければフローを終了 

348 if self.next_step: 

349 ctx.goto(self.next_step) 

350 else: 

351 ctx.finish() 

352 

353 return ctx 

354 

355 

356 

357class DebugStep(Step): 

358 """ 

359 Step for debugging and logging 

360 デバッグとログ用ステップ 

361  

362 This step prints or logs context information for debugging purposes. 

363 このステップはデバッグ目的でコンテキスト情報を印刷またはログ出力します。 

364 """ 

365 

366 def __init__(self, name: str, message: str = "", print_context: bool = False, next_step: Optional[str] = None): 

367 """ 

368 Initialize debug step 

369 デバッグステップを初期化 

370  

371 Args: 

372 name: Step name / ステップ名 

373 message: Debug message / デバッグメッセージ 

374 print_context: Whether to print full context / 完全なコンテキストを印刷するか 

375 next_step: Next step / 次ステップ 

376 """ 

377 super().__init__(name) 

378 self.message = message 

379 self.print_context = print_context 

380 self.next_step = next_step 

381 

382 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

383 """ 

384 Execute debug step 

385 デバッグステップを実行 

386  

387 Args: 

388 user_input: User input / ユーザー入力 

389 ctx: Current context / 現在のコンテキスト 

390  

391 Returns: 

392 Context: Updated context / 更新済みコンテキスト 

393 """ 

394 ctx.update_step_info(self.name) 

395 

396 # Print debug information 

397 # デバッグ情報を印刷 

398 print(f"🐛 DEBUG [{self.name}]: {self.message}") 

399 if user_input: 

400 print(f" User Input: {user_input}") 

401 print(f" Step Count: {ctx.step_count}") 

402 print(f" Next Label: {ctx.next_label}") 

403 

404 if self.print_context: 

405 print(f" Context: {ctx.model_dump()}") 

406 

407 # Add debug message to system messages 

408 # デバッグメッセージをシステムメッセージに追加 

409 ctx.add_system_message(f"DEBUG {self.name}: {self.message}") 

410 

411 # Set next step if specified, otherwise finish the flow 

412 # 指定されている場合は次ステップを設定、そうでなければフローを終了 

413 if self.next_step: 

414 ctx.goto(self.next_step) 

415 else: 

416 ctx.finish() 

417 

418 return ctx 

419 

420 

421# Utility functions for creating common step patterns 

422# 一般的なステップパターンを作成するユーティリティ関数 

423 

424def create_simple_condition(field_path: str, expected_value: Any) -> Callable[[Context], bool]: 

425 """ 

426 Create a simple condition function that checks a field value 

427 フィールド値をチェックする簡単な条件関数を作成 

428  

429 Args: 

430 field_path: Dot-separated path to field (e.g., "shared_state.status") / フィールドへのドット区切りパス 

431 expected_value: Expected value / 期待値 

432  

433 Returns: 

434 Callable[[Context], bool]: Condition function / 条件関数 

435 """ 

436 def condition(ctx: Context) -> bool: 

437 try: 

438 # Navigate to the field using dot notation 

439 # ドット記法を使用してフィールドに移動 

440 obj = ctx 

441 for part in field_path.split('.'): 

442 if hasattr(obj, part): 

443 obj = getattr(obj, part) 

444 elif isinstance(obj, dict) and part in obj: 

445 obj = obj[part] 

446 else: 

447 return False 

448 return obj == expected_value 

449 except Exception: 

450 return False 

451 

452 return condition 

453 

454 

455def create_lambda_step(name: str, func: Callable[[Context], Any], next_step: Optional[str] = None) -> FunctionStep: 

456 """ 

457 Create a simple function step from a lambda 

458 ラムダから簡単な関数ステップを作成 

459  

460 Args: 

461 name: Step name / ステップ名 

462 func: Function to execute / 実行する関数 

463 next_step: Next step / 次ステップ 

464  

465 Returns: 

466 FunctionStep: Function step / 関数ステップ 

467 """ 

468 def wrapper(user_input: Optional[str], ctx: Context) -> Context: 

469 func(ctx) 

470 return ctx 

471 

472 return FunctionStep(name, wrapper, next_step) 

473 

474 

475class ParallelStep(Step): 

476 """ 

477 Step that executes multiple steps in parallel 

478 複数のステップを並列実行するステップ 

479  

480 This step automatically manages parallel execution of child steps. 

481 このステップは子ステップの並列実行を自動管理します。 

482 It waits for all parallel steps to complete before proceeding. 

483 全ての並列ステップが完了するまで待機してから進行します。 

484 """ 

485 

486 def __init__( 

487 self, 

488 name: str, 

489 parallel_steps: List[Step], 

490 next_step: Optional[str] = None, 

491 max_workers: Optional[int] = None 

492 ): 

493 """ 

494 Initialize parallel step 

495 並列ステップを初期化 

496  

497 Args: 

498 name: Step name / ステップ名 

499 parallel_steps: List of steps to execute in parallel / 並列実行するステップのリスト 

500 next_step: Next step after all parallel steps complete / 全並列ステップ完了後の次ステップ 

501 max_workers: Maximum number of concurrent workers / 最大同時ワーカー数 

502 """ 

503 super().__init__(name) 

504 self.parallel_steps = parallel_steps 

505 self.next_step = next_step 

506 self.max_workers = max_workers or min(32, len(parallel_steps) + 4) 

507 

508 # Validate that all steps have names 

509 # 全ステップに名前があることを検証 

510 for step in parallel_steps: 

511 if not hasattr(step, 'name') or not step.name: 

512 raise ValueError(f"All parallel steps must have valid names: {step}") 

513 

514 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

515 """ 

516 Execute parallel steps 

517 並列ステップを実行 

518  

519 Args: 

520 user_input: User input (passed to all parallel steps) / ユーザー入力(全並列ステップに渡される) 

521 ctx: Current context / 現在のコンテキスト 

522  

523 Returns: 

524 Context: Updated context with merged results / マージされた結果を持つ更新済みコンテキスト 

525 """ 

526 ctx.update_step_info(self.name) 

527 

528 # Create separate contexts for each parallel step 

529 # 各並列ステップ用に別々のコンテキストを作成 

530 parallel_contexts = [] 

531 for step in self.parallel_steps: 

532 # Clone context for each parallel execution 

533 # 各並列実行用にコンテキストをクローン 

534 step_ctx = self._clone_context_for_parallel(ctx, step.name) 

535 parallel_contexts.append((step, step_ctx)) 

536 

537 # Execute all steps in parallel 

538 # 全ステップを並列実行 

539 async def run_parallel_step(step_and_ctx): 

540 step, step_ctx = step_and_ctx 

541 try: 

542 result_ctx = await step.run(user_input, step_ctx) 

543 return step.name, result_ctx, None 

544 except Exception as e: 

545 return step.name, step_ctx, e 

546 

547 # Use asyncio.gather for parallel execution 

548 # 並列実行にasyncio.gatherを使用 

549 results = await asyncio.gather( 

550 *[run_parallel_step(sc) for sc in parallel_contexts], 

551 return_exceptions=True 

552 ) 

553 

554 # Merge results back into main context 

555 # 結果をメインコンテキストにマージ 

556 errors = [] 

557 for result in results: 

558 if isinstance(result, Exception): 

559 errors.append(result) 

560 continue 

561 

562 step_name, result_ctx, error = result 

563 if error: 

564 errors.append(f"Step {step_name}: {error}") 

565 continue 

566 

567 # Merge parallel step results 

568 # 並列ステップ結果をマージ 

569 self._merge_parallel_result(ctx, step_name, result_ctx) 

570 

571 # Handle errors if any 

572 # エラーがあれば処理 

573 if errors: 

574 error_msg = f"Parallel execution errors: {'; '.join(map(str, errors))}" 

575 ctx.add_system_message(error_msg) 

576 raise RuntimeError(error_msg) 

577 

578 # Set next step or finish 

579 # 次ステップを設定または終了 

580 if self.next_step: 

581 ctx.goto(self.next_step) 

582 else: 

583 ctx.finish() 

584 

585 return ctx 

586 

587 def _clone_context_for_parallel(self, ctx: Context, step_name: str) -> Context: 

588 """ 

589 Clone context for parallel execution 

590 並列実行用にコンテキストをクローン 

591  

592 Args: 

593 ctx: Original context / 元のコンテキスト 

594 step_name: Name of the step / ステップ名 

595  

596 Returns: 

597 Context: Cloned context / クローンされたコンテキスト 

598 """ 

599 # Create new context with shared state 

600 # 共有状態を持つ新しいコンテキストを作成 

601 cloned_ctx = Context() 

602 

603 # Copy essential state 

604 # 必須状態をコピー 

605 cloned_ctx.shared_state = ctx.shared_state.copy() 

606 cloned_ctx.messages = ctx.messages.copy() 

607 cloned_ctx.last_user_input = ctx.last_user_input 

608 cloned_ctx.trace_id = ctx.trace_id 

609 cloned_ctx.span_history = ctx.span_history.copy() 

610 

611 # Set step-specific information 

612 # ステップ固有情報を設定 

613 cloned_ctx.current_step = step_name 

614 

615 return cloned_ctx 

616 

617 def _merge_parallel_result(self, main_ctx: Context, step_name: str, result_ctx: Context) -> None: 

618 """ 

619 Merge parallel step result into main context 

620 並列ステップ結果をメインコンテキストにマージ 

621  

622 Args: 

623 main_ctx: Main context / メインコンテキスト 

624 step_name: Name of the completed step / 完了したステップ名 

625 result_ctx: Result context from parallel step / 並列ステップからの結果コンテキスト 

626 """ 

627 # Merge shared state with step-specific keys 

628 # ステップ固有キーで共有状態をマージ 

629 for key, value in result_ctx.shared_state.items(): 

630 if key not in main_ctx.shared_state: 

631 main_ctx.shared_state[key] = value 

632 else: 

633 # Handle conflicts by prefixing with step name 

634 # ステップ名をプレフィックスとして衝突を処理 

635 prefixed_key = f"{step_name}_{key}" 

636 main_ctx.shared_state[prefixed_key] = value 

637 

638 # Merge conversation history 

639 # 会話履歴をマージ 

640 main_ctx.messages.extend(result_ctx.messages) 

641 

642 # Update execution path 

643 # 実行パスを更新 

644 main_ctx.span_history.extend(result_ctx.span_history) 

645 

646 # Store step-specific results metadata (avoid overwriting user data) 

647 # ステップ固有結果メタデータを保存(ユーザーデータの上書きを避ける) 

648 main_ctx.shared_state[f"__{step_name}_metadata__"] = { 

649 "status": "completed", 

650 "output": result_ctx.shared_state, 

651 "messages": result_ctx.messages 

652 }