Coverage for src/refinire/agents/pipeline/pipeline.py: 18% (196 statements)
coverage.py v7.9.1, created at 2025-06-15 18:51 +0900

"""Pipeline — ultra-light builder for the OpenAI Agents SDK.

v1.5 — **Guardrails support**
 • Adds `generation_guardrails` / `evaluation_guardrails` for generation and evaluation, respectively
 • Injected via `Agent(..., guardrails=…)` and applied at runtime
"""

from __future__ import annotations

from dataclasses import dataclass, is_dataclass
from typing import Callable, List, Dict, Any, Optional, Type
import json
import re
import textwrap  # Dedents the multi-line JSON instruction string.
from enum import Enum  # Defines comment importance levels.

from agents import Agent, Runner
from ...core.llm import get_llm
from ...core.message import get_message  # Localized messages

try:
    from pydantic import BaseModel  # type: ignore
except ImportError:
    BaseModel = object  # type: ignore

# Enum for comment importance levels.
class CommentImportance(Enum):
    SERIOUS = "serious"
    NORMAL = "normal"
    MINOR = "minor"


@dataclass
class Comment:
    """
    Evaluation comment with an importance level and content.

    Attributes:
        importance: Importance level of the comment (serious/normal/minor)
        content: Text content of the comment
    """
    importance: CommentImportance  # Importance level (serious/normal/minor)
    content: str  # Comment text


@dataclass
class EvaluationResult:
    """
    Result of evaluating generated content.

    Attributes:
        score: Evaluation score (0-100)
        comment: List of Comment instances, each with an importance and content
    """
    score: int  # Evaluation score (0-100)
    comment: List[Comment]  # List of evaluation comments
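
# A sketch of the evaluator JSON payload that maps onto EvaluationResult; the
# values below are illustrative only:
#   {"score": 92, "comment": [{"importance": "minor", "content": "Tighten the intro."}]}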

class AgentPipeline:
    """
    AgentPipeline manages the generation and evaluation of content using the
    OpenAI Agents SDK.

    .. deprecated:: 0.0.22
        AgentPipeline is deprecated and will be removed in v0.1.0.
        Use GenAgent with the Flow/Step architecture instead.
        See the migration guide: docs/deprecation_plan.md

    This class handles:
    - Content generation using instructions
    - Content evaluation with scoring
    - Session history management
    - Output formatting and routing

    Preferred alternatives:
    - Use GenAgent for single-step pipeline functionality
    - Use the Flow/Step architecture for complex workflows
    - See examples/gen_agent_example.py for migration examples
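
    Example:
        A minimal usage sketch; the model name is illustrative and assumes a
        configured OpenAI Agents SDK environment::

            pipeline = AgentPipeline(
                name="demo",
                generation_instructions="You are a helpful assistant.",
                evaluation_instructions="Score the answer for accuracy and clarity.",
                model="gpt-4o-mini",
                threshold=80,
            )
            result = pipeline.run("Explain dataclasses in one paragraph.")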

82 """ 

    def __init__(
        self,
        name: str,
        generation_instructions: str,
        evaluation_instructions: Optional[str],
        *,
        input_guardrails: Optional[list] = None,
        output_guardrails: Optional[list] = None,
        output_model: Optional[Type[Any]] = None,
        model: str | None = None,
        evaluation_model: str | None = None,
        generation_tools: Optional[list] = None,
        evaluation_tools: Optional[list] = None,
        routing_func: Optional[Callable[[Any], Any]] = None,
        session_history: Optional[list] = None,
        history_size: int = 10,
        threshold: int = 85,
        retries: int = 3,
        improvement_callback: Optional[Callable[[Any, EvaluationResult], None]] = None,
        dynamic_prompt: Optional[Callable[[str], str]] = None,
        retry_comment_importance: Optional[list[str]] = None,
        locale: str = "en",
    ) -> None:
        """
        Initialize the pipeline with its configuration parameters.

        .. deprecated:: 0.0.22
            AgentPipeline is deprecated and will be removed in v0.1.0.
            Use GenAgent with the Flow/Step architecture instead.
            See the migration guide: docs/deprecation_plan.md

        Args:
            name: Pipeline name
            generation_instructions: System prompt for generation
            evaluation_instructions: System prompt for evaluation
            input_guardrails: Guardrails applied to generation
            output_guardrails: Guardrails applied to evaluation
            output_model: Model used to format the output
            model: LLM model name
            evaluation_model: Optional LLM model name for evaluation; if None, `model` is used
            generation_tools: Tools available during generation
            evaluation_tools: Tools available during evaluation
            routing_func: Function for output routing
            session_history: Session history
            history_size: Number of history entries to keep
            threshold: Evaluation score threshold
            retries: Number of retry attempts
            improvement_callback: Callback invoked with improvement suggestions
            dynamic_prompt: Optional function that builds the prompt dynamically
            retry_comment_importance: Importance levels of comments to include in the prompt on retry
            locale: Language code for localized messages ("en" or "ja")
        """

        import warnings
        warnings.warn(
            "AgentPipeline is deprecated and will be removed in v0.1.0. "
            "Use GenAgent with Flow/Step architecture instead. "
            "See migration guide: docs/deprecation_plan.md",
            DeprecationWarning,
            stacklevel=2,
        )
        self.name = name
        self.generation_instructions = generation_instructions.strip()
        self.evaluation_instructions = evaluation_instructions.strip() if evaluation_instructions else None
        self.output_model = output_model

        self.model = model
        self.evaluation_model = evaluation_model
        self.generation_tools = generation_tools or []
        self.evaluation_tools = evaluation_tools or []
        self.input_guardrails = input_guardrails or []
        self.output_guardrails = output_guardrails or []
        self.routing_func = routing_func
        self.session_history = session_history if session_history is not None else []
        self.history_size = history_size
        self.threshold = threshold
        self.retries = retries
        self.improvement_callback = improvement_callback
        self.dynamic_prompt = dynamic_prompt
        self.retry_comment_importance = retry_comment_importance or []
        # Language code for localized messages ("en" or "ja")
        self.locale = locale

        # Get the generation LLM instance; get_llm applies its default tracing setting.
        llm = get_llm(model) if model else None
        # Determine the evaluation LLM instance, falling back to the generation
        # model when evaluation_model is None.
        eval_source = evaluation_model if evaluation_model else model
        llm_eval = get_llm(eval_source) if eval_source else None

        # Agents ---------------------------------------------------------
        self.gen_agent = Agent(
            name=f"{name}_generator",
            model=llm,
            tools=self.generation_tools,
            instructions=self.generation_instructions,
            input_guardrails=self.input_guardrails,
        )

        # Localized evaluation format instructions
        format_header = get_message("eval_output_format_header", self.locale)
        schema_instruction = get_message("eval_json_schema_instruction", self.locale)
        # The JSON schema itself is not localized.
        json_schema = textwrap.dedent("""\
            {
                "score": int(0~100),
                "comment": [
                    {
                        "importance": "serious" | "normal" | "minor",
                        "content": str
                    }
                ]
            }
        """)
        json_instr = "\n".join(["----", format_header, schema_instruction, json_schema])
        self.eval_agent = (
            Agent(
                name=f"{name}_evaluator",
                model=llm_eval,
                tools=self.evaluation_tools,
                instructions=self.evaluation_instructions + json_instr,
                output_guardrails=self.output_guardrails,
            )
            if self.evaluation_instructions
            else None
        )

        self._runner = Runner()
        self._pipeline_history: List[Dict[str, str]] = []

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    def _build_generation_prompt(self, user_input: str) -> str:
        """
        Build the prompt for content generation.

        Args:
            user_input: User input text

        Returns:
            str: The formatted generation prompt
        """
        recent = "\n".join(f"User: {h['input']}\nAI: {h['output']}"
                           for h in self._pipeline_history[-self.history_size:])
        session = "\n".join(self.session_history)
        # Use the localized prefix for the user input line.
        prefix = get_message("user_input_prefix", self.locale)
        return "\n".join(filter(None, [session, recent, f"{prefix} {user_input}"]))

    def _build_evaluation_prompt(self, user_input: str, generated_output: str) -> str:
        """
        Build the prompt for content evaluation.

        Args:
            user_input: Original user input
            generated_output: Generated content to evaluate

        Returns:
            str: The formatted evaluation prompt
        """
        parts = []

        # Add the evaluation instructions if provided.
        if self.evaluation_instructions:
            parts.append(self.evaluation_instructions)

        parts.extend([
            "----",
            f"User input:\n{user_input}",
            "----",
            f"Generated output:\n{generated_output}",
            "Return the above strictly as JSON in the following format."
        ])
        return "\n".join(filter(None, parts)).strip()

    @staticmethod
    def _extract_json(text: str) -> Dict[str, Any]:
        """
        Extract JSON from text.

        Args:
            text: Text containing JSON

        Returns:
            Dict[str, Any]: Extracted JSON data

        Raises:
            ValueError: If no JSON is found in the text
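
        Example:
            >>> AgentPipeline._extract_json('noise {"score": 90, "comment": []} noise')
            {'score': 90, 'comment': []}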

279 """ 

280 match = re.search(r"\{.*\}", text, re.S) 

281 if not match: 

282 raise ValueError("JSON not found in evaluation output") 

283 return json.loads(match.group(0)) 

    def _coerce_output(self, text: str):
        """
        Convert the output to the specified model format.

        Args:
            text: Output text to convert

        Returns:
            Any: Converted output in the specified format
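
        Example:
            A sketch with a hypothetical pydantic model as output_model::

                class Answer(BaseModel):
                    text: str

            With output_model=Answer, the string '{"text": "hi"}' coerces to
            Answer(text="hi"); anything that fails to parse or validate is
            returned unchanged.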

295 """ 

296 if self.output_model is None: 

297 return text 

298 try: 

299 data = json.loads(text) 

300 except json.JSONDecodeError: 

301 return text 

302 try: 

303 if isinstance(self.output_model, type) and issubclass(self.output_model, BaseModel): 

304 return self.output_model.model_validate(data) 

305 if is_dataclass(self.output_model): 

306 return self.output_model(**data) 

307 return self.output_model(**data) 

308 except Exception: 

309 return text 

    def _append_to_session(self, user_input: str, raw_output: str):
        """
        Append an interaction to the session history.

        Args:
            user_input: User input text
            raw_output: Generated output text
        """
        if self.session_history is None:
            return
        self.session_history.append(f"User: {user_input}\nAI: {raw_output}")

    def _route(self, parsed_output):
        """
        Route the parsed output through the routing function, if one is set.

        Args:
            parsed_output: Parsed output to route

        Returns:
            Any: Routed output
        """
        return self.routing_func(parsed_output) if self.routing_func else parsed_output

    # ------------------------------------------------------------------
    # public
    # ------------------------------------------------------------------

    async def run_async(self, user_input: str):
        """
        Run the pipeline asynchronously with the given user input.

        Args:
            user_input: User input text

        Returns:
            Any: Processed output, or None if evaluation fails
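
        Example:
            A sketch of driving the async path::

                import asyncio
                result = asyncio.run(pipeline.run_async("Summarize this text."))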

351 """ 

352 attempt = 0 

353 last_eval_result: Optional[EvaluationResult] = None # Store last evaluation result for retry 

354 while attempt <= self.retries: 

355 # ---------------- Generation ---------------- 

356 # On retry, include prior evaluation comments if configured 

357 if attempt > 0 and last_eval_result and self.retry_comment_importance: 

358 # Filter comments by importance 

359 try: 

360 comments = [c for c in last_eval_result.comment if c.get("importance") in self.retry_comment_importance] 

361 except Exception: 

362 comments = [] 

363 # Format serious comments with header 

364 # Localized header for evaluation feedback 

365 feedback_header = get_message("evaluation_feedback_header", self.locale) 

366 # English: Format each comment line. 日本語: 各コメント行をフォーマット 

367 formatted_comments = [f"- ({c.get('importance')}) {c.get('content')}" for c in comments] 

368 # English: Combine header and comment lines. 日本語: ヘッダーとコメント行を結合 

369 comment_block = "\n".join([feedback_header] + formatted_comments) 

370 else: 

371 comment_block = "" 

372 # Build base prompt 

373 if attempt > 0 and comment_block: 

374 if self.dynamic_prompt: 

375 # English: Use dynamic prompt if provided. 日本語: dynamic_promptがあればそれを使用 

376 gen_prompt = self.dynamic_prompt(user_input) 

377 else: 

378 # Localized header for AI history 

379 ai_history_header = get_message("ai_history_header", self.locale) 

380 # English: Extract AI outputs from pipeline history, omit user inputs. 日本語: パイプライン履歴からAIの出力のみ取得 

381 ai_outputs = "\n".join(h["output"] for h in self._pipeline_history[-self.history_size:]) 

382 # Localized prefix for user input line 

383 prefix = get_message("user_input_prefix", self.locale) 

384 # English: Current user input line. 日本語: 現在のユーザー入力行 

385 user_input_line = f"{prefix} {user_input}" 

386 # English: Combine AI outputs, feedback, and current user input. 日本語: AI出力、フィードバック、現在のユーザー入力を結合 

387 gen_prompt = "\n\n".join([ai_history_header, ai_outputs, comment_block, user_input_line]) 

388 else: 

389 if self.dynamic_prompt: 

390 gen_prompt = self.dynamic_prompt(user_input) 

391 else: 

392 gen_prompt = self._build_generation_prompt(user_input) 

393 

394 from agents import Runner 

395 gen_result = await Runner.run(self.gen_agent, gen_prompt) 

396 raw_output_text = getattr(gen_result, "final_output", str(gen_result)) 

397 if hasattr(gen_result, "tool_calls") and gen_result.tool_calls: 

398 raw_output_text = str(gen_result.tool_calls[0].call()) 

399 

400 parsed_output = self._coerce_output(raw_output_text) 

401 self._pipeline_history.append({"input": user_input, "output": raw_output_text}) 

402 

403 # ---------------- Evaluation ---------------- 

404 if not self.eval_agent: 

405 return self._route(parsed_output) 

406 

407 eval_prompt = self._build_evaluation_prompt(user_input, raw_output_text) 

408 

409 eval_raw = await Runner.run(self.eval_agent, eval_prompt) 

410 eval_text = getattr(eval_raw, "final_output", str(eval_raw)) 

411 try: 

412 eval_dict = self._extract_json(eval_text) 

413 eval_result = EvaluationResult(**eval_dict) 

414 except Exception: 

415 eval_result = EvaluationResult(score=0, comment=[Comment(importance=CommentImportance.SERIOUS, content="評価 JSON の解析に失敗")]) 

416 

417 if eval_result.score >= self.threshold: 

418 self._append_to_session(user_input, raw_output_text) 

419 return self._route(parsed_output) 

420 

421 # Store for next retry 

422 last_eval_result = eval_result 

423 attempt += 1 

424 

425 if self.improvement_callback: 

426 self.improvement_callback(parsed_output, eval_result) 

427 return None 

    def run(self, user_input: str):
        """
        Run the pipeline synchronously with the given user input.

        Args:
            user_input: User input text

        Returns:
            Any: Processed output, or None if evaluation fails
        """
        attempt = 0
        last_eval_result: Optional[EvaluationResult] = None  # Last evaluation result, kept for retries
        while attempt <= self.retries:
            # ---------------- Generation ----------------
            # On retry, include prior evaluation comments if configured.
            if attempt > 0 and last_eval_result and self.retry_comment_importance:
                # Filter comments by importance. Comments parsed from the evaluator's
                # JSON arrive as plain dicts; anything else is skipped defensively.
                try:
                    comments = [c for c in last_eval_result.comment
                                if c.get("importance") in self.retry_comment_importance]
                except Exception:
                    comments = []
                # Localized header for evaluation feedback.
                feedback_header = get_message("evaluation_feedback_header", self.locale)
                # Format each comment line.
                formatted_comments = [f"- ({c.get('importance')}) {c.get('content')}" for c in comments]
                # Combine the header and the comment lines.
                comment_block = "\n".join([feedback_header] + formatted_comments)
            else:
                comment_block = ""
            # Build the base prompt.
            if attempt > 0 and comment_block:
                if self.dynamic_prompt:
                    # Use the dynamic prompt if provided.
                    gen_prompt = self.dynamic_prompt(user_input)
                else:
                    # Localized header for the AI history block.
                    ai_history_header = get_message("ai_history_header", self.locale)
                    # Keep only the AI outputs from the pipeline history; user inputs are omitted.
                    ai_outputs = "\n".join(h["output"] for h in self._pipeline_history[-self.history_size:])
                    # Localized prefix for the current user input line.
                    prefix = get_message("user_input_prefix", self.locale)
                    user_input_line = f"{prefix} {user_input}"
                    # Combine the AI outputs, feedback, and current user input.
                    gen_prompt = "\n\n".join([ai_history_header, ai_outputs, comment_block, user_input_line])
            else:
                if self.dynamic_prompt:
                    gen_prompt = self.dynamic_prompt(user_input)
                else:
                    gen_prompt = self._build_generation_prompt(user_input)

            gen_result = self._runner.run_sync(self.gen_agent, gen_prompt)
            raw_output_text = getattr(gen_result, "final_output", str(gen_result))
            if hasattr(gen_result, "tool_calls") and gen_result.tool_calls:
                raw_output_text = str(gen_result.tool_calls[0].call())

            parsed_output = self._coerce_output(raw_output_text)
            self._pipeline_history.append({"input": user_input, "output": raw_output_text})

            # ---------------- Evaluation ----------------
            if not self.eval_agent:
                return self._route(parsed_output)

            eval_prompt = self._build_evaluation_prompt(user_input, raw_output_text)

            eval_raw = self._runner.run_sync(self.eval_agent, eval_prompt)
            eval_text = getattr(eval_raw, "final_output", str(eval_raw))
            try:
                eval_dict = self._extract_json(eval_text)
                eval_result = EvaluationResult(**eval_dict)
            except Exception:
                eval_result = EvaluationResult(
                    score=0,
                    comment=[Comment(importance=CommentImportance.SERIOUS,
                                     content="Failed to parse evaluation JSON")],
                )

            if eval_result.score >= self.threshold:
                self._append_to_session(user_input, raw_output_text)
                return self._route(parsed_output)

            # Store the result for the next retry.
            last_eval_result = eval_result
            attempt += 1

        if self.improvement_callback:
            self.improvement_callback(parsed_output, eval_result)
        return None
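

# A minimal smoke-test sketch (hypothetical; assumes API credentials are set and
# the agents SDK is configured; the model name is illustrative):
#
# if __name__ == "__main__":
#     pipeline = AgentPipeline(
#         name="demo",
#         generation_instructions="Answer concisely.",
#         evaluation_instructions="Score the answer for accuracy (0-100).",
#         model="gpt-4o-mini",
#         retry_comment_importance=["serious", "normal"],
#     )
#     print(pipeline.run("What is a dataclass?"))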