Coverage for src\agents_sdk_models\step.py: 80%

150 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-04 17:38 +0900

1from __future__ import annotations 

2 

3"""Step — Step interface and basic implementations for Flow workflows. 

4 

5Stepはフローワークフロー用のステップインターフェースと基本実装を提供します。 

6UserInputStep、ConditionStep、ForkStep、JoinStepなどの基本的なステップを含みます。 

7""" 

8 

9import asyncio 

10from abc import ABC, abstractmethod 

11from typing import Any, Callable, Dict, List, Optional, Union, Awaitable 

12from concurrent.futures import ThreadPoolExecutor 

13import threading 

14 

15from .context import Context 

16 

17 

18class Step(ABC): 

19 """ 

20 Abstract base class for workflow steps 

21 ワークフローステップの抽象基底クラス 

22  

23 All step implementations must provide: 

24 全てのステップ実装は以下を提供する必要があります: 

25 - name: Step identifier for DSL reference / DSL参照用ステップ識別子 

26 - run: Async execution method / 非同期実行メソッド 

27 """ 

28 

29 def __init__(self, name: str): 

30 """ 

31 Initialize step with name 

32 名前でステップを初期化 

33  

34 Args: 

35 name: Step name / ステップ名 

36 """ 

37 self.name = name 

38 

39 @abstractmethod 

40 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

41 """ 

42 Execute step and return updated context 

43 ステップを実行し、更新されたコンテキストを返す 

44  

45 Args: 

46 user_input: User input if any / ユーザー入力(あれば) 

47 ctx: Current context / 現在のコンテキスト 

48  

49 Returns: 

50 Context: Updated context with next_label set / next_labelが設定された更新済みコンテキスト 

51 """ 

52 pass 

53 

54 def __str__(self) -> str: 

55 return f"{self.__class__.__name__}({self.name})" 

56 

57 def __repr__(self) -> str: 

58 return self.__str__() 

59 

60 

61class UserInputStep(Step): 

62 """ 

63 Step that waits for user input 

64 ユーザー入力を待機するステップ 

65  

66 This step displays a prompt and waits for user response. 

67 このステップはプロンプトを表示し、ユーザー応答を待機します。 

68 It sets the context to waiting state and returns without advancing. 

69 コンテキストを待機状態に設定し、進行せずに返します。 

70 """ 

71 

72 def __init__(self, name: str, prompt: str, next_step: Optional[str] = None): 

73 """ 

74 Initialize user input step 

75 ユーザー入力ステップを初期化 

76  

77 Args: 

78 name: Step name / ステップ名 

79 prompt: Prompt to display to user / ユーザーに表示するプロンプト 

80 next_step: Next step after input (optional) / 入力後の次ステップ(オプション) 

81 """ 

82 super().__init__(name) 

83 self.prompt = prompt 

84 self.next_step = next_step 

85 

86 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

87 """ 

88 Execute user input step 

89 ユーザー入力ステップを実行 

90  

91 Args: 

92 user_input: User input if available / 利用可能なユーザー入力 

93 ctx: Current context / 現在のコンテキスト 

94  

95 Returns: 

96 Context: Updated context / 更新済みコンテキスト 

97 """ 

98 ctx.update_step_info(self.name) 

99 

100 # If user input is provided, process it 

101 # ユーザー入力が提供されている場合、処理する 

102 if user_input is not None: 

103 ctx.provide_user_input(user_input) 

104 if self.next_step: 

105 ctx.goto(self.next_step) 

106 # Note: If next_step is None, flow will end 

107 # 注:next_stepがNoneの場合、フローは終了 

108 else: 

109 # Set waiting state for user input 

110 # ユーザー入力の待機状態を設定 

111 ctx.set_waiting_for_user_input(self.prompt) 

112 

113 return ctx 

114 

115 

116class ConditionStep(Step): 

117 """ 

118 Step that performs conditional routing 

119 条件付きルーティングを実行するステップ 

120  

121 This step evaluates a condition and routes to different steps based on the result. 

122 このステップは条件を評価し、結果に基づいて異なるステップにルーティングします。 

123 """ 

124 

125 def __init__( 

126 self, 

127 name: str, 

128 condition: Callable[[Context], Union[bool, Awaitable[bool]]], 

129 if_true: str, 

130 if_false: str 

131 ): 

132 """ 

133 Initialize condition step 

134 条件ステップを初期化 

135  

136 Args: 

137 name: Step name / ステップ名 

138 condition: Condition function / 条件関数 

139 if_true: Step to go if condition is True / 条件がTrueの場合のステップ 

140 if_false: Step to go if condition is False / 条件がFalseの場合のステップ 

141 """ 

142 super().__init__(name) 

143 self.condition = condition 

144 self.if_true = if_true 

145 self.if_false = if_false 

146 

147 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

148 """ 

149 Execute condition step 

150 条件ステップを実行 

151  

152 Args: 

153 user_input: User input (not used) / ユーザー入力(使用されない) 

154 ctx: Current context / 現在のコンテキスト 

155  

156 Returns: 

157 Context: Updated context with routing / ルーティング付き更新済みコンテキスト 

158 """ 

159 ctx.update_step_info(self.name) 

160 

161 # Evaluate condition (may be async) 

162 # 条件を評価(非同期の可能性あり) 

163 try: 

164 result = self.condition(ctx) 

165 if asyncio.iscoroutine(result): 

166 result = await result 

167 except Exception as e: 

168 # On error, go to false branch 

169 # エラー時はfalseブランチに進む 

170 ctx.add_system_message(f"Condition evaluation error: {e}") 

171 result = False 

172 

173 # Route based on condition result 

174 # 条件結果に基づいてルーティング 

175 next_step = self.if_true if result else self.if_false 

176 ctx.goto(next_step) 

177 

178 return ctx 

179 

180 

181class FunctionStep(Step): 

182 """ 

183 Step that executes a custom function 

184 カスタム関数を実行するステップ 

185  

186 This step allows executing arbitrary code within the workflow. 

187 このステップはワークフロー内で任意のコードを実行できます。 

188 """ 

189 

190 def __init__( 

191 self, 

192 name: str, 

193 function: Callable[[Optional[str], Context], Union[Context, Awaitable[Context]]], 

194 next_step: Optional[str] = None 

195 ): 

196 """ 

197 Initialize function step 

198 関数ステップを初期化 

199  

200 Args: 

201 name: Step name / ステップ名 

202 function: Function to execute / 実行する関数 

203 next_step: Next step after execution / 実行後の次ステップ 

204 """ 

205 super().__init__(name) 

206 self.function = function 

207 self.next_step = next_step 

208 

209 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

210 """ 

211 Execute function step 

212 関数ステップを実行 

213  

214 Args: 

215 user_input: User input / ユーザー入力 

216 ctx: Current context / 現在のコンテキスト 

217  

218 Returns: 

219 Context: Updated context / 更新済みコンテキスト 

220 """ 

221 ctx.update_step_info(self.name) 

222 

223 try: 

224 # Execute the function (may be async) 

225 # 関数を実行(非同期の可能性あり) 

226 result = self.function(user_input, ctx) 

227 if asyncio.iscoroutine(result): 

228 ctx = await result 

229 else: 

230 ctx = result 

231 except Exception as e: 

232 ctx.add_system_message(f"Function execution error in {self.name}: {e}") 

233 

234 # Set next step if specified 

235 # 指定されている場合は次ステップを設定 

236 if self.next_step: 

237 ctx.goto(self.next_step) 

238 

239 return ctx 

240 

241 

242class ForkStep(Step): 

243 """ 

244 Step that executes multiple branches in parallel 

245 複数のブランチを並列実行するステップ 

246  

247 This step starts multiple sub-flows concurrently and collects their results. 

248 このステップは複数のサブフローを同時に開始し、結果を収集します。 

249 """ 

250 

251 def __init__(self, name: str, branches: List[str], join_step: str): 

252 """ 

253 Initialize fork step 

254 フォークステップを初期化 

255  

256 Args: 

257 name: Step name / ステップ名 

258 branches: List of branch step names to execute in parallel / 並列実行するブランチステップ名のリスト 

259 join_step: Step to join results / 結果を結合するステップ 

260 """ 

261 super().__init__(name) 

262 self.branches = branches 

263 self.join_step = join_step 

264 

265 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

266 """ 

267 Execute fork step 

268 フォークステップを実行 

269  

270 Args: 

271 user_input: User input / ユーザー入力 

272 ctx: Current context / 現在のコンテキスト 

273  

274 Returns: 

275 Context: Updated context / 更新済みコンテキスト 

276 """ 

277 ctx.update_step_info(self.name) 

278 

279 # Store branch information for join step 

280 # ジョインステップ用にブランチ情報を保存 

281 ctx.shared_state[f"{self.name}_branches"] = self.branches 

282 ctx.shared_state[f"{self.name}_started"] = True 

283 

284 # For now, just route to the join step 

285 # 現在のところ、ジョインステップにルーティングするだけ 

286 # In a full implementation, this would start parallel execution 

287 # 完全な実装では、これは並列実行を開始する 

288 ctx.goto(self.join_step) 

289 

290 return ctx 

291 

292 

293class JoinStep(Step): 

294 """ 

295 Step that joins results from parallel branches 

296 並列ブランチからの結果を結合するステップ 

297  

298 This step waits for parallel branches to complete and merges their results. 

299 このステップは並列ブランチの完了を待機し、結果をマージします。 

300 """ 

301 

302 def __init__(self, name: str, fork_step: str, join_type: str = "all", next_step: Optional[str] = None): 

303 """ 

304 Initialize join step 

305 ジョインステップを初期化 

306  

307 Args: 

308 name: Step name / ステップ名 

309 fork_step: Associated fork step name / 関連するフォークステップ名 

310 join_type: Join type ("all" or "any") / ジョインタイプ("all"または"any") 

311 next_step: Next step after join / ジョイン後の次ステップ 

312 """ 

313 super().__init__(name) 

314 self.fork_step = fork_step 

315 self.join_type = join_type 

316 self.next_step = next_step 

317 

318 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

319 """ 

320 Execute join step 

321 ジョインステップを実行 

322  

323 Args: 

324 user_input: User input / ユーザー入力 

325 ctx: Current context / 現在のコンテキスト 

326  

327 Returns: 

328 Context: Updated context / 更新済みコンテキスト 

329 """ 

330 ctx.update_step_info(self.name) 

331 

332 # Get branch information from shared state 

333 # 共有状態からブランチ情報を取得 

334 branches = ctx.shared_state.get(f"{self.fork_step}_branches", []) 

335 

336 # For now, just mark as completed 

337 # 現在のところ、完了としてマークするだけ 

338 # In a full implementation, this would wait for and merge branch results 

339 # 完全な実装では、これはブランチ結果を待機してマージする 

340 ctx.add_system_message(f"Joined {len(branches)} branches using {self.join_type} strategy") 

341 

342 # Set next step if specified 

343 # 指定されている場合は次ステップを設定 

344 if self.next_step: 

345 ctx.goto(self.next_step) 

346 

347 return ctx 

348 

349 

350class AgentPipelineStep(Step): 

351 """ 

352 Step that wraps AgentPipeline for use in Flow 

353 FlowでAgentPipelineを使用するためのラッパーステップ 

354  

355 This step allows using existing AgentPipeline instances as flow steps. 

356 このステップは既存のAgentPipelineインスタンスをフローステップとして使用できます。 

357 """ 

358 

359 def __init__(self, name: str, pipeline: Any, next_step: Optional[str] = None): 

360 """ 

361 Initialize agent pipeline step 

362 エージェントパイプラインステップを初期化 

363  

364 Args: 

365 name: Step name / ステップ名 

366 pipeline: AgentPipeline instance / AgentPipelineインスタンス 

367 next_step: Next step after pipeline execution / パイプライン実行後の次ステップ 

368 """ 

369 super().__init__(name) 

370 self.pipeline = pipeline 

371 self.next_step = next_step 

372 

373 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

374 """ 

375 Execute agent pipeline step 

376 エージェントパイプラインステップを実行 

377  

378 Args: 

379 user_input: User input / ユーザー入力 

380 ctx: Current context / 現在のコンテキスト 

381  

382 Returns: 

383 Context: Updated context / 更新済みコンテキスト 

384 """ 

385 ctx.update_step_info(self.name) 

386 

387 try: 

388 # Use the last user input if available 

389 # 利用可能な場合は最後のユーザー入力を使用 

390 input_text = user_input or ctx.last_user_input or "" 

391 

392 # Execute pipeline in thread pool to handle sync methods 

393 # 同期メソッドを処理するためにスレッドプールでパイプラインを実行 

394 loop = asyncio.get_event_loop() 

395 with ThreadPoolExecutor() as executor: 

396 future = loop.run_in_executor(executor, self.pipeline.run, input_text) 

397 result = await future 

398 

399 # Store result in context 

400 # 結果をコンテキストに保存 

401 if result is not None: 

402 ctx.prev_outputs[self.name] = result 

403 ctx.add_assistant_message(str(result)) 

404 

405 except Exception as e: 

406 ctx.add_system_message(f"Pipeline execution error in {self.name}: {e}") 

407 ctx.prev_outputs[self.name] = None 

408 

409 # Set next step if specified 

410 # 指定されている場合は次ステップを設定 

411 if self.next_step: 

412 ctx.goto(self.next_step) 

413 

414 return ctx 

415 

416 

417class DebugStep(Step): 

418 """ 

419 Step for debugging and logging 

420 デバッグとログ用ステップ 

421  

422 This step prints or logs context information for debugging purposes. 

423 このステップはデバッグ目的でコンテキスト情報を印刷またはログ出力します。 

424 """ 

425 

426 def __init__(self, name: str, message: str = "", print_context: bool = False, next_step: Optional[str] = None): 

427 """ 

428 Initialize debug step 

429 デバッグステップを初期化 

430  

431 Args: 

432 name: Step name / ステップ名 

433 message: Debug message / デバッグメッセージ 

434 print_context: Whether to print full context / 完全なコンテキストを印刷するか 

435 next_step: Next step / 次ステップ 

436 """ 

437 super().__init__(name) 

438 self.message = message 

439 self.print_context = print_context 

440 self.next_step = next_step 

441 

442 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

443 """ 

444 Execute debug step 

445 デバッグステップを実行 

446  

447 Args: 

448 user_input: User input / ユーザー入力 

449 ctx: Current context / 現在のコンテキスト 

450  

451 Returns: 

452 Context: Updated context / 更新済みコンテキスト 

453 """ 

454 ctx.update_step_info(self.name) 

455 

456 # Print debug information 

457 # デバッグ情報を印刷 

458 print(f"🐛 DEBUG [{self.name}]: {self.message}") 

459 if user_input: 

460 print(f" User Input: {user_input}") 

461 print(f" Step Count: {ctx.step_count}") 

462 print(f" Next Label: {ctx.next_label}") 

463 

464 if self.print_context: 

465 print(f" Context: {ctx.dict()}") 

466 

467 # Add debug message to system messages 

468 # デバッグメッセージをシステムメッセージに追加 

469 ctx.add_system_message(f"DEBUG {self.name}: {self.message}") 

470 

471 # Set next step if specified 

472 # 指定されている場合は次ステップを設定 

473 if self.next_step: 

474 ctx.goto(self.next_step) 

475 

476 return ctx 

477 

478 

479# Utility functions for creating common step patterns 

480# 一般的なステップパターンを作成するユーティリティ関数 

481 

482def create_simple_condition(field_path: str, expected_value: Any) -> Callable[[Context], bool]: 

483 """ 

484 Create a simple condition function that checks a field value 

485 フィールド値をチェックする簡単な条件関数を作成 

486  

487 Args: 

488 field_path: Dot-separated path to field (e.g., "shared_state.status") / フィールドへのドット区切りパス 

489 expected_value: Expected value / 期待値 

490  

491 Returns: 

492 Callable[[Context], bool]: Condition function / 条件関数 

493 """ 

494 def condition(ctx: Context) -> bool: 

495 try: 

496 # Navigate to the field using dot notation 

497 # ドット記法を使用してフィールドに移動 

498 obj = ctx 

499 for part in field_path.split('.'): 

500 if hasattr(obj, part): 

501 obj = getattr(obj, part) 

502 elif isinstance(obj, dict) and part in obj: 

503 obj = obj[part] 

504 else: 

505 return False 

506 return obj == expected_value 

507 except Exception: 

508 return False 

509 

510 return condition 

511 

512 

513def create_lambda_step(name: str, func: Callable[[Context], Any], next_step: Optional[str] = None) -> FunctionStep: 

514 """ 

515 Create a simple function step from a lambda 

516 ラムダから簡単な関数ステップを作成 

517  

518 Args: 

519 name: Step name / ステップ名 

520 func: Function to execute / 実行する関数 

521 next_step: Next step / 次ステップ 

522  

523 Returns: 

524 FunctionStep: Function step / 関数ステップ 

525 """ 

526 def wrapper(user_input: Optional[str], ctx: Context) -> Context: 

527 func(ctx) 

528 return ctx 

529 

530 return FunctionStep(name, wrapper, next_step)