Coverage for src/refinire/agents/pipeline/pipeline.py: 18% (196 statements)
from __future__ import annotations

"""Pipeline — ultra-light builder for OpenAI Agents SDK.

v1.5 — **Guardrails support**
 • Added `generation_guardrails` / `evaluation_guardrails` for generation and evaluation respectively
 • Injected via `Agent(..., guardrails=…)` and applied at run time
"""
from dataclasses import dataclass, is_dataclass
from typing import Callable, List, Dict, Any, Optional, Type
import json
import re
import textwrap  # English: Import textwrap for dedenting multi-line JSON instruction strings. 日本語: JSON指示文字列の字下げを削除するためにtextwrapをインポートします。
from enum import Enum  # English: Import Enum for defining comment importance levels. 日本語: コメント重要度レベル定義用Enumをインポートします。
from agents import Agent, Runner
from ...core.llm import get_llm
from ...core.message import get_message  # Import for localized messages
try:
    from pydantic import BaseModel  # type: ignore
except ImportError:
    BaseModel = object  # type: ignore

# English: Enum for comment importance levels.
# 日本語: コメントの重要度レベルを表す列挙型
class CommentImportance(Enum):
    SERIOUS = "serious"  # English: Serious importance. 日本語: シリアス
    NORMAL = "normal"    # English: Normal importance. 日本語: ノーマル
    MINOR = "minor"      # English: Minor importance. 日本語: マイナー

@dataclass
class Comment:
    """
    Evaluation comment with importance and content
    評価コメントの重要度と内容を保持するクラス

    Attributes:
        importance: Importance level of the comment (serious/normal/minor) / コメントの重要度レベル(シリアス/ノーマル/マイナー)
        content: Text content of the comment / コメント内容
    """
    importance: CommentImportance  # Importance level (serious/normal/minor) / 重要度レベル(シリアス/ノーマル/マイナー)
    content: str  # Comment text / コメント内容

@dataclass
class EvaluationResult:
    """
    Result of evaluation for generated content
    生成されたコンテンツの評価結果を保持するクラス

    Attributes:
        score: Evaluation score (0-100) / 評価スコア(0-100)
        comment: List of Comment instances containing importance and content / 重要度と内容を持つCommentクラスのリスト
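
    Example:
        An illustrative instance; at run time the evaluator's JSON output is
        parsed into this shape:

        EvaluationResult(
            score=90,
            comment=[Comment(importance=CommentImportance.MINOR, content="Tighten the title")],
        )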
55 """
56 score: int # Evaluation score (0-100) / 評価スコア(0-100)
57 comment: List[Comment] # List of evaluation comments / 評価コメントのリスト

class AgentPipeline:
    """
    AgentPipeline class for managing the generation and evaluation of content using OpenAI Agents SDK
    OpenAI Agents SDKを使用してコンテンツの生成と評価を管理するパイプラインクラス

    .. deprecated:: 0.0.22
        AgentPipeline is deprecated and will be removed in v0.1.0.
        Use GenAgent with Flow/Step architecture instead.
        See migration guide: docs/deprecation_plan.md

    This class handles:
    このクラスは以下を処理します:
    - Content generation using instructions / instructionsを使用したコンテンツ生成
    - Content evaluation with scoring / スコアリングによるコンテンツ評価
    - Session history management / セッション履歴の管理
    - Output formatting and routing / 出力のフォーマットとルーティング

    Preferred alternative:
    推奨代替手段:
    - Use GenAgent for single-step pipeline functionality
    - Use Flow/Step architecture for complex workflows
    - See examples/gen_agent_example.py for migration examples
    """

    def __init__(
        self,
        name: str,
        generation_instructions: str,
        evaluation_instructions: Optional[str],
        *,
        input_guardrails: Optional[list] = None,
        output_guardrails: Optional[list] = None,
        output_model: Optional[Type[Any]] = None,
        model: str | None = None,
        evaluation_model: str | None = None,
        generation_tools: Optional[list] = None,
        evaluation_tools: Optional[list] = None,
        routing_func: Optional[Callable[[Any], Any]] = None,
        session_history: Optional[list] = None,
        history_size: int = 10,
        threshold: int = 85,
        retries: int = 3,
        improvement_callback: Optional[Callable[[Any, EvaluationResult], None]] = None,
        dynamic_prompt: Optional[Callable[[str], str]] = None,
        retry_comment_importance: Optional[list[str]] = None,
        locale: str = "en",
    ) -> None:
107 """
108 Initialize the Pipeline with configuration parameters
109 設定パラメータでパイプラインを初期化する
111 .. deprecated:: 0.0.22
112 AgentPipeline is deprecated and will be removed in v0.1.0.
113 Use GenAgent with Flow/Step architecture instead.
114 See migration guide: docs/deprecation_plan.md
116 Args:
117 name: Pipeline name / パイプライン名
118 generation_instructions: System prompt for generation / 生成用システムプロンプト
119 evaluation_instructions: System prompt for evaluation / 評価用システムプロンプト
120 input_guardrails: Guardrails for generation / 生成用ガードレール
121 output_guardrails: Guardrails for evaluation / 評価用ガードレール
122 output_model: Model for output formatting / 出力フォーマット用モデル
123 model: LLM model name / LLMモデル名
124 evaluation_model: Optional LLM model name for evaluation; if None, uses model. 日本語: 評価用のLLMモデル名(Noneの場合はmodelを使用)
125 generation_tools: Tools for generation / 生成用ツール
126 evaluation_tools: Tools for evaluation / 評価用ツール
127 routing_func: Function for output routing / 出力ルーティング用関数
128 session_history: Session history / セッション履歴
129 history_size: Size of history to keep / 保持する履歴サイズ
130 threshold: Evaluation score threshold / 評価スコアの閾値
131 retries: Number of retry attempts / リトライ試行回数
132 improvement_callback: Callback for improvement suggestions / 改善提案用コールバック
133 dynamic_prompt: Optional function to dynamically build prompt / 動的プロンプト生成関数(任意)
134 retry_comment_importance: Importance levels of comments to include on retry / リトライ時にプロンプトに含めるコメントの重大度レベル(任意)
135 locale: Language code for localized messages ("en" or "ja")
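
        Example:
            A minimal construction sketch (the model name and instructions are
            illustrative):

            pipeline = AgentPipeline(
                name="writer",
                generation_instructions="You are a helpful copywriter.",
                evaluation_instructions="Score the copy for clarity.",
                model="gpt-4o-mini",
                threshold=80,
            )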
136 """
137 import warnings
138 warnings.warn(
139 "AgentPipeline is deprecated and will be removed in v0.1.0. "
140 "Use GenAgent with Flow/Step architecture instead. "
141 "See migration guide: docs/deprecation_plan.md",
142 DeprecationWarning,
143 stacklevel=2
144 )
        self.name = name
        self.generation_instructions = generation_instructions.strip()
        self.evaluation_instructions = evaluation_instructions.strip() if evaluation_instructions else None
        self.output_model = output_model

        self.model = model
        self.evaluation_model = evaluation_model
        self.generation_tools = generation_tools or []
        self.evaluation_tools = evaluation_tools or []
        self.input_guardrails = input_guardrails or []
        self.output_guardrails = output_guardrails or []
        self.routing_func = routing_func
        self.session_history = session_history if session_history is not None else []
        self.history_size = history_size
        self.threshold = threshold
        self.retries = retries
        self.improvement_callback = improvement_callback
        self.dynamic_prompt = dynamic_prompt
        self.retry_comment_importance = retry_comment_importance or []
        # Language code for localized messages ("en" or "ja")
        self.locale = locale

        # English: Get generation LLM instance; default tracing setting applied in get_llm
        # 日本語: 生成用LLMインスタンスを取得します。tracing設定はget_llm側でデフォルト値を使用
        llm = get_llm(model) if model else None
        # English: Determine evaluation LLM instance, fallback to generation model if evaluation_model is None
        # 日本語: 評価用LLMインスタンスを決定。evaluation_modelがNoneの場合は生成モデルを使用
        eval_source = evaluation_model if evaluation_model else model
        llm_eval = get_llm(eval_source) if eval_source else None

        # Agents ---------------------------------------------------------
        self.gen_agent = Agent(
            name=f"{name}_generator",
            model=llm,
            tools=self.generation_tools,
            instructions=self.generation_instructions,
            input_guardrails=self.input_guardrails,
        )

        # Localized evaluation format instructions
        format_header = get_message("eval_output_format_header", self.locale)
        schema_instruction = get_message("eval_json_schema_instruction", self.locale)
        # JSON schema remains unlocalized
        json_schema = textwrap.dedent("""\
            {
                "score": int(0~100),
                "comment": [
                    {
                        "importance": "serious" | "normal" | "minor",  # Importance field / 重要度フィールド
                        "content": str  # Comment content / コメント内容
                    }
                ]
            }
        """)
        # Separator line, localized header and instruction, then the schema itself
        json_instr = "\n".join(["----", format_header, schema_instruction, json_schema])
        self.eval_agent = (
            Agent(
                name=f"{name}_evaluator",
                model=llm_eval,
                tools=self.evaluation_tools,
                instructions=self.evaluation_instructions + "\n" + json_instr,
                output_guardrails=self.output_guardrails,
            )
            if self.evaluation_instructions
            else None
        )

        self._runner = Runner()
        self._pipeline_history: List[Dict[str, str]] = []

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    def _build_generation_prompt(self, user_input: str) -> str:
        """
        Build the prompt for content generation
        コンテンツ生成用のプロンプトを構築する

        Args:
            user_input: User input text / ユーザー入力テキスト

        Returns:
            str: Formatted prompt for generation / 生成用のフォーマット済みプロンプト
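
        Example:
            Illustrative shape of the returned prompt; the actual prefix on
            the last line comes from get_message("user_input_prefix", locale),
            so the label shown here is an assumption:

            User: earlier question
            AI: earlier answer
            UserInput: current question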
229 """
230 recent = "\n".join(f"User: {h['input']}\nAI: {h['output']}"
231 for h in self._pipeline_history[-self.history_size:])
232 session = "\n".join(self.session_history)
233 # Use localized prefix for user input
234 prefix = get_message("user_input_prefix", self.locale)
235 return "\n".join(filter(None, [session, recent, f"{prefix} {user_input}"]))

    def _build_evaluation_prompt(self, user_input: str, generated_output: str) -> str:
        """
        Build the prompt for content evaluation
        コンテンツ評価用のプロンプトを構築する

        Args:
            user_input: Original user input / 元のユーザー入力
            generated_output: Generated content to evaluate / 評価対象の生成コンテンツ

        Returns:
            str: Formatted prompt for evaluation / 評価用のフォーマット済みプロンプト
        """
        parts = []

        # Add evaluation instructions if provided
        # 評価指示が提供されている場合は追加
        if self.evaluation_instructions:
            parts.append(self.evaluation_instructions)

        # Note: these labels are currently hardcoded in Japanese
        # ("User input", "Generated result", "Return the above as JSON in exactly the following format")
        parts.extend([
            "----",
            f"ユーザー入力:\n{user_input}",
            "----",
            f"生成結果:\n{generated_output}",
            "上記を JSON で必ず次の形式にしてください"
        ])
        return "\n".join(filter(None, parts)).strip()

    @staticmethod
    def _extract_json(text: str) -> Dict[str, Any]:
        """
        Extract JSON from text
        テキストからJSONを抽出する

        Args:
            text: Text containing JSON / JSONを含むテキスト

        Returns:
            Dict[str, Any]: Extracted JSON data / 抽出されたJSONデータ

        Raises:
            ValueError: If JSON is not found in text / テキスト内にJSONが見つからない場合
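
        Example:
            Doctest-style illustration of the extraction:

            >>> AgentPipeline._extract_json('noise {"score": 90, "comment": []} noise')
            {'score': 90, 'comment': []}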
279 """
280 match = re.search(r"\{.*\}", text, re.S)
281 if not match:
282 raise ValueError("JSON not found in evaluation output")
283 return json.loads(match.group(0))

    def _coerce_output(self, text: str):
        """
        Convert output to specified model format
        出力を指定されたモデル形式に変換する

        Args:
            text: Output text to convert / 変換対象の出力テキスト

        Returns:
            Any: Converted output in specified format / 指定された形式の変換済み出力
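
        Example:
            A sketch assuming output_model is a pydantic model; the Tagline
            class is hypothetical:

            class Tagline(BaseModel):
                text: str

            # With output_model=Tagline, a JSON string is validated into the model:
            # _coerce_output('{"text": "Fresh beans daily"}') -> Tagline(text='Fresh beans daily')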
295 """
296 if self.output_model is None:
297 return text
298 try:
299 data = json.loads(text)
300 except json.JSONDecodeError:
301 return text
302 try:
303 if isinstance(self.output_model, type) and issubclass(self.output_model, BaseModel):
304 return self.output_model.model_validate(data)
305 if is_dataclass(self.output_model):
306 return self.output_model(**data)
307 return self.output_model(**data)
308 except Exception:
309 return text

    def _append_to_session(self, user_input: str, raw_output: str):
        """
        Append interaction to session history
        セッション履歴にインタラクションを追加する

        Args:
            user_input: User input text / ユーザー入力テキスト
            raw_output: Generated output text / 生成された出力テキスト
        """
        if self.session_history is None:
            return
        self.session_history.append(f"User: {user_input}\nAI: {raw_output}")

    def _route(self, parsed_output):
        """
        Route the parsed output through routing function if specified
        指定されている場合、パース済み出力をルーティング関数で処理する

        Args:
            parsed_output: Parsed output to route / ルーティング対象のパース済み出力

        Returns:
            Any: Routed output / ルーティング済み出力
        """
        return self.routing_func(parsed_output) if self.routing_func else parsed_output

    # ------------------------------------------------------------------
    # public
    # ------------------------------------------------------------------

    async def run_async(self, user_input: str):
        """
        Run the pipeline asynchronously with user input
        ユーザー入力でパイプラインを非同期実行する

        Args:
            user_input: User input text / ユーザー入力テキスト

        Returns:
            Any: Processed output or None if evaluation fails / 処理済み出力、または評価失敗時はNone
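
        Example:
            A minimal async usage sketch (assumes an asyncio event loop; the
            pipeline instance is as constructed in the __init__ example):

            import asyncio
            result = asyncio.run(pipeline.run_async("Write a tagline."))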
351 """
352 attempt = 0
353 last_eval_result: Optional[EvaluationResult] = None # Store last evaluation result for retry
354 while attempt <= self.retries:
355 # ---------------- Generation ----------------
356 # On retry, include prior evaluation comments if configured
357 if attempt > 0 and last_eval_result and self.retry_comment_importance:
358 # Filter comments by importance
359 try:
360 comments = [c for c in last_eval_result.comment if c.get("importance") in self.retry_comment_importance]
361 except Exception:
362 comments = []
363 # Format serious comments with header
364 # Localized header for evaluation feedback
365 feedback_header = get_message("evaluation_feedback_header", self.locale)
366 # English: Format each comment line. 日本語: 各コメント行をフォーマット
367 formatted_comments = [f"- ({c.get('importance')}) {c.get('content')}" for c in comments]
368 # English: Combine header and comment lines. 日本語: ヘッダーとコメント行を結合
369 comment_block = "\n".join([feedback_header] + formatted_comments)
370 else:
371 comment_block = ""
            # Build base prompt
            if attempt > 0 and comment_block:
                if self.dynamic_prompt:
                    # English: Use dynamic prompt if provided. 日本語: dynamic_promptがあればそれを使用
                    gen_prompt = self.dynamic_prompt(user_input)
                else:
                    # Localized header for AI history
                    ai_history_header = get_message("ai_history_header", self.locale)
                    # English: Extract AI outputs from pipeline history, omit user inputs. 日本語: パイプライン履歴からAIの出力のみ取得
                    ai_outputs = "\n".join(h["output"] for h in self._pipeline_history[-self.history_size:])
                    # Localized prefix for user input line
                    prefix = get_message("user_input_prefix", self.locale)
                    # English: Current user input line. 日本語: 現在のユーザー入力行
                    user_input_line = f"{prefix} {user_input}"
                    # English: Combine AI outputs, feedback, and current user input. 日本語: AI出力、フィードバック、現在のユーザー入力を結合
                    gen_prompt = "\n\n".join([ai_history_header, ai_outputs, comment_block, user_input_line])
            else:
                if self.dynamic_prompt:
                    gen_prompt = self.dynamic_prompt(user_input)
                else:
                    gen_prompt = self._build_generation_prompt(user_input)

            gen_result = await Runner.run(self.gen_agent, gen_prompt)
            raw_output_text = getattr(gen_result, "final_output", str(gen_result))
            if hasattr(gen_result, "tool_calls") and gen_result.tool_calls:
                raw_output_text = str(gen_result.tool_calls[0].call())

            parsed_output = self._coerce_output(raw_output_text)
            self._pipeline_history.append({"input": user_input, "output": raw_output_text})

            # ---------------- Evaluation ----------------
            if not self.eval_agent:
                return self._route(parsed_output)

            eval_prompt = self._build_evaluation_prompt(user_input, raw_output_text)

            eval_raw = await Runner.run(self.eval_agent, eval_prompt)
            eval_text = getattr(eval_raw, "final_output", str(eval_raw))
            try:
                eval_dict = self._extract_json(eval_text)
                eval_result = EvaluationResult(**eval_dict)
            except Exception:
                # Fallback message: "Failed to parse evaluation JSON"
                eval_result = EvaluationResult(score=0, comment=[Comment(importance=CommentImportance.SERIOUS, content="評価 JSON の解析に失敗")])

            if eval_result.score >= self.threshold:
                self._append_to_session(user_input, raw_output_text)
                return self._route(parsed_output)

            # Store for next retry
            last_eval_result = eval_result
            attempt += 1

        if self.improvement_callback:
            self.improvement_callback(parsed_output, eval_result)
        return None

    def run(self, user_input: str):
        """
        Run the pipeline with user input
        ユーザー入力でパイプラインを実行する

        Args:
            user_input: User input text / ユーザー入力テキスト

        Returns:
            Any: Processed output or None if evaluation fails / 処理済み出力、または評価失敗時はNone
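
        Example:
            A minimal synchronous usage sketch (pipeline as constructed in the
            __init__ example):

            result = pipeline.run("Write a tagline for a coffee shop.")
            if result is None:
                print("Output did not pass the evaluation threshold.")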
439 """
440 attempt = 0
441 last_eval_result: Optional[EvaluationResult] = None # Store last evaluation result for retry
442 while attempt <= self.retries:
443 # ---------------- Generation ----------------
444 # On retry, include prior evaluation comments if configured
445 if attempt > 0 and last_eval_result and self.retry_comment_importance:
446 # Filter comments by importance
447 try:
448 comments = [c for c in last_eval_result.comment if c.get("importance") in self.retry_comment_importance]
449 except Exception:
450 comments = []
451 # Format serious comments with header
452 # Localized header for evaluation feedback
453 feedback_header = get_message("evaluation_feedback_header", self.locale)
454 # English: Format each comment line. 日本語: 各コメント行をフォーマット
455 formatted_comments = [f"- ({c.get('importance')}) {c.get('content')}" for c in comments]
456 # English: Combine header and comment lines. 日本語: ヘッダーとコメント行を結合
457 comment_block = "\n".join([feedback_header] + formatted_comments)
458 else:
459 comment_block = ""
            # Build base prompt
            if attempt > 0 and comment_block:
                if self.dynamic_prompt:
                    # English: Use dynamic prompt if provided. 日本語: dynamic_promptがあればそれを使用
                    gen_prompt = self.dynamic_prompt(user_input)
                else:
                    # Localized header for AI history
                    ai_history_header = get_message("ai_history_header", self.locale)
                    # English: Extract AI outputs from pipeline history, omit user inputs. 日本語: パイプライン履歴からAIの出力のみ取得
                    ai_outputs = "\n".join(h["output"] for h in self._pipeline_history[-self.history_size:])
                    # Localized prefix for user input line
                    prefix = get_message("user_input_prefix", self.locale)
                    # English: Current user input line. 日本語: 現在のユーザー入力行
                    user_input_line = f"{prefix} {user_input}"
                    # English: Combine AI outputs, feedback, and current user input. 日本語: AI出力、フィードバック、現在のユーザー入力を結合
                    gen_prompt = "\n\n".join([ai_history_header, ai_outputs, comment_block, user_input_line])
            else:
                if self.dynamic_prompt:
                    gen_prompt = self.dynamic_prompt(user_input)
                else:
                    gen_prompt = self._build_generation_prompt(user_input)

            gen_result = self._runner.run_sync(self.gen_agent, gen_prompt)
            raw_output_text = getattr(gen_result, "final_output", str(gen_result))
            if hasattr(gen_result, "tool_calls") and gen_result.tool_calls:
                raw_output_text = str(gen_result.tool_calls[0].call())

            parsed_output = self._coerce_output(raw_output_text)
            self._pipeline_history.append({"input": user_input, "output": raw_output_text})

            # ---------------- Evaluation ----------------
            if not self.eval_agent:
                return self._route(parsed_output)

            eval_prompt = self._build_evaluation_prompt(user_input, raw_output_text)

            eval_raw = self._runner.run_sync(self.eval_agent, eval_prompt)
            eval_text = getattr(eval_raw, "final_output", str(eval_raw))
            try:
                eval_dict = self._extract_json(eval_text)
                eval_result = EvaluationResult(**eval_dict)
            except Exception:
                # Fallback message: "Failed to parse evaluation JSON"
                eval_result = EvaluationResult(score=0, comment=[Comment(importance=CommentImportance.SERIOUS, content="評価 JSON の解析に失敗")])

            if eval_result.score >= self.threshold:
                self._append_to_session(user_input, raw_output_text)
                return self._route(parsed_output)

            # Store for next retry
            last_eval_result = eval_result
            attempt += 1

        if self.improvement_callback:
            self.improvement_callback(parsed_output, eval_result)
        return None