Coverage for src/refinire/agents/router.py: 90%
133 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 15:27 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 15:27 +0900
1"""
2RouterAgent implementation for routing inputs based on classification.
4The RouterAgent analyzes input data and routes it to appropriate processing paths
5based on configurable routing logic and classification results.
6"""
7# RouterAgent implementation for routing inputs based on classification.
8# RouterAgentは入力データを分析し、設定可能なルーティングロジックと分類結果に基づいて適切な処理パスにルーティングします。
10from typing import Any, Dict, List, Optional, Union, Callable, Literal
11from pydantic import BaseModel, Field, field_validator
12from abc import ABC, abstractmethod
13import logging
15from .flow.step import Step
16from .flow.context import Context
17from .pipeline.llm_pipeline import LLMPipeline, create_simple_llm_pipeline
19logger = logging.getLogger(__name__)
22class RouteClassifier(ABC):
23 """
24 Abstract base class for route classification logic.
25 ルート分類ロジックの抽象基底クラス。
26 """
28 @abstractmethod
29 def classify(self, input_data: Any, context: Context) -> str:
30 """
31 Classify input data and return the route key.
32 入力データを分類してルートキーを返します。
34 Args:
35 input_data: The input data to classify / 分類する入力データ
36 context: The execution context / 実行コンテキスト
38 Returns:
39 str: The route key for the classified input / 分類された入力のルートキー
40 """
41 pass
44class LLMClassifier(RouteClassifier):
45 """
46 LLM-based classifier for route determination.
47 ルート決定のためのLLMベース分類器。
48 """
50 def __init__(
51 self,
52 pipeline: LLMPipeline,
53 classification_prompt: str,
54 routes: List[str],
55 examples: Optional[Dict[str, List[str]]] = None
56 ):
57 """
58 Initialize LLM classifier.
59 LLM分類器を初期化します。
61 Args:
62 pipeline: LLM pipeline for classification / 分類用のLLMパイプライン
63 classification_prompt: Prompt template for classification / 分類用のプロンプトテンプレート
64 routes: List of possible route keys / 可能なルートキーのリスト
65 examples: Optional examples for each route / 各ルートのオプション例
66 """
67 self.pipeline = pipeline
68 self.classification_prompt = classification_prompt
69 self.routes = routes
70 self.examples = examples or {}
72 def classify(self, input_data: Any, context: Context) -> str:
73 """
74 Classify input using LLM.
75 LLMを使用して入力を分類します。
76 """
77 # Build classification prompt with examples
78 # 例を含む分類プロンプトを構築
79 examples_text = ""
80 if self.examples:
81 examples_text = "\n\nExamples:\n"
82 for route, example_list in self.examples.items():
83 examples_text += f"Route '{route}':\n"
84 for example in example_list:
85 examples_text += f"- {example}\n"
87 routes_text = ", ".join(self.routes)
89 full_prompt = f"""
90{self.classification_prompt}
92Available routes: {routes_text}
93{examples_text}
95Input to classify: {input_data}
97Respond with only the route key (one of: {routes_text})
98"""
100 try:
101 result = self.pipeline.run(full_prompt, context)
103 # Clean and validate the result
104 # 結果をクリーンアップして検証
105 classified_route = result.strip().lower()
107 # Find matching route (case insensitive)
108 # 一致するルートを検索(大文字小文字を区別しない)
109 for route in self.routes:
110 if route.lower() == classified_route:
111 return route
113 # If no exact match, try partial matching
114 # 完全一致しない場合、部分一致を試行
115 for route in self.routes:
116 if route.lower() in classified_route or classified_route in route.lower():
117 logger.warning(f"Partial route match: '{classified_route}' -> '{route}'")
118 return route
120 # Return None to let RouterAgent handle fallback
121 # RouterAgentにフォールバックを処理させるためNoneを返す
122 logger.warning(f"Could not classify input, returning None for RouterAgent fallback")
123 return None
125 except Exception as e:
126 logger.error(f"Classification error: {e}")
127 return None # Let RouterAgent handle fallback / RouterAgentにフォールバックを処理させる
130class RuleBasedClassifier(RouteClassifier):
131 """
132 Rule-based classifier using callable functions.
133 呼び出し可能な関数を使用したルールベース分類器。
134 """
136 def __init__(self, rules: Dict[str, Callable[[Any, Context], bool]]):
137 """
138 Initialize rule-based classifier.
139 ルールベース分類器を初期化します。
141 Args:
142 rules: Dictionary mapping route keys to classification functions
143 ルートキーを分類関数にマッピングする辞書
144 """
145 self.rules = rules
147 def classify(self, input_data: Any, context: Context) -> str:
148 """
149 Classify input using rules.
150 ルールを使用して入力を分類します。
151 """
152 for route_key, rule_func in self.rules.items():
153 try:
154 if rule_func(input_data, context):
155 return route_key
156 except Exception as e:
157 logger.warning(f"Rule evaluation error for route '{route_key}': {e}")
158 continue
160 # If no rules match, return the first route as fallback
161 # ルールに一致しない場合、最初のルートをフォールバックとして返す
162 fallback_route = next(iter(self.rules.keys()))
163 logger.warning(f"No rules matched, using fallback route: {fallback_route}")
164 return fallback_route
167class RouterConfig(BaseModel):
168 """
169 Configuration for RouterAgent.
170 RouterAgentの設定。
171 """
173 name: str = Field(description="Name of the router agent / ルーターエージェントの名前")
175 routes: Dict[str, str] = Field(
176 description="Mapping of route keys to next step names / ルートキーから次のステップ名へのマッピング"
177 )
179 classifier_type: Literal["llm", "rule"] = Field(
180 default="llm",
181 description="Type of classifier to use / 使用する分類器のタイプ"
182 )
184 # LLM classifier options
185 classification_prompt: Optional[str] = Field(
186 default=None,
187 description="Prompt for LLM classification / LLM分類用のプロンプト"
188 )
190 classification_examples: Optional[Dict[str, List[str]]] = Field(
191 default=None,
192 description="Examples for each route / 各ルートの例"
193 )
195 # Rule-based classifier options
196 classification_rules: Optional[Dict[str, Callable[[Any, Context], bool]]] = Field(
197 default=None,
198 description="Rules for classification / 分類のためのルール"
199 )
201 # Fallback options
202 default_route: Optional[str] = Field(
203 default=None,
204 description="Default route if classification fails / 分類が失敗した場合のデフォルトルート"
205 )
207 store_classification_result: bool = Field(
208 default=True,
209 description="Whether to store classification result in context / 分類結果をコンテキストに保存するかどうか"
210 )
212 @field_validator("routes")
213 @classmethod
214 def routes_not_empty(cls, v):
215 """Validate that routes are not empty / ルートが空でないことを検証"""
216 if not v:
217 raise ValueError("Routes cannot be empty / ルートは空にできません")
218 return v
220 @field_validator("default_route")
221 @classmethod
222 def default_route_exists(cls, v, info):
223 """Validate that default route exists in routes / デフォルトルートがルートに存在することを検証"""
224 if v is not None and "routes" in info.data and v not in info.data["routes"]:
225 raise ValueError(f"Default route '{v}' must exist in routes / デフォルトルート'{v}'はルートに存在する必要があります")
226 return v
229class RouterAgent(Step):
230 """
231 Router agent that classifies input and routes to appropriate next steps.
232 入力を分類して適切な次のステップにルーティングするルーターエージェント。
234 The RouterAgent analyzes input data using either LLM-based or rule-based
235 classification and determines which processing path the input should follow.
236 RouterAgentはLLMベースまたはルールベースの分類を使用して入力データを分析し、
237 入力がどの処理パスに従うべきかを決定します。
238 """
240 def __init__(self, config: RouterConfig, llm_pipeline: Optional[LLMPipeline] = None):
241 """
242 Initialize RouterAgent.
243 RouterAgentを初期化します。
245 Args:
246 config: Router configuration / ルーター設定
247 llm_pipeline: Optional LLM pipeline for LLM-based classification
248 LLMベース分類用のオプションのLLMパイプライン
249 """
250 super().__init__(name=config.name)
251 self.config = config
253 # Initialize classifier based on type
254 # タイプに基づいて分類器を初期化
255 if config.classifier_type == "llm":
256 if llm_pipeline is None:
257 # Create default LLM pipeline if none provided
258 # 提供されていない場合はデフォルトのLLMパイプラインを作成
259 llm_pipeline = create_simple_llm_pipeline(
260 name="router_classifier",
261 instructions="You are a classification assistant. Classify the input text into the provided categories."
262 )
264 # Use provided prompt or create default
265 # 提供されたプロンプトを使用するか、デフォルトを作成
266 prompt = config.classification_prompt or self._create_default_classification_prompt()
268 self.classifier = LLMClassifier(
269 pipeline=llm_pipeline,
270 classification_prompt=prompt,
271 routes=list(config.routes.keys()),
272 examples=config.classification_examples
273 )
275 elif config.classifier_type == "rule":
276 if config.classification_rules is None:
277 raise ValueError("classification_rules must be provided for rule-based classifier")
279 self.classifier = RuleBasedClassifier(config.classification_rules)
281 else:
282 raise ValueError(f"Unsupported classifier type: {config.classifier_type}")
284 def _create_default_classification_prompt(self) -> str:
285 """
286 Create default classification prompt.
287 デフォルトの分類プロンプトを作成します。
288 """
289 return f"""
290Classify the given input into one of the available routes.
291Consider the context and content of the input to determine the most appropriate route.
293Available routes: {', '.join(self.config.routes.keys())}
295Choose the route that best matches the input's intent, content, or characteristics.
296"""
298 async def run(self, user_input: Optional[str], ctx: Context) -> Context:
299 """
300 Execute the routing logic.
301 ルーティングロジックを実行します。
303 Args:
304 user_input: User input to classify and route / 分類・ルーティングするユーザー入力
305 ctx: Execution context / 実行コンテキスト
307 Returns:
308 Context: Updated context with routing information / ルーティング情報を含む更新されたコンテキスト
309 """
310 # Update step info
311 # ステップ情報を更新
312 ctx.update_step_info(self.name)
314 try:
315 # Use user_input as input_data for classification
316 # 分類にuser_inputをinput_dataとして使用
317 input_data = user_input or ctx.get_user_input() or ""
319 # Classify the input
320 # 入力を分類
321 route_key = self.classifier.classify(input_data, ctx)
323 # Check if classification failed (returned None)
324 # 分類が失敗したかチェック(Noneが返された)
325 classification_failed = route_key is None
327 # Validate route exists
328 # ルートが存在することを検証
329 if route_key is None or route_key not in self.config.routes:
330 logger.warning(f"Invalid route key '{route_key}', using default")
331 route_key = self.config.default_route or next(iter(self.config.routes.keys()))
333 # Store classification result in context if requested
334 # 要求された場合は分類結果をコンテキストに保存
335 if self.config.store_classification_result:
336 ctx.shared_state[f"{self.name}_classification"] = route_key
337 ctx.shared_state[f"{self.name}_next_step"] = self.config.routes[route_key]
339 # Store error info if classification failed
340 # 分類が失敗した場合はエラー情報を保存
341 if classification_failed:
342 ctx.shared_state[f"{self.name}_error"] = "Classification failed"
344 # Set the next step for flow control
345 # フロー制御用に次のステップを設定
346 next_step_name = self.config.routes[route_key]
347 ctx.goto(next_step_name)
349 logger.info(f"RouterAgent '{self.name}' classified input as '{route_key}' -> '{next_step_name}'")
351 return ctx
353 except Exception as e:
354 logger.error(f"RouterAgent '{self.name}' error: {e}")
356 # Use fallback route
357 # フォールバックルートを使用
358 fallback_route = self.config.default_route or next(iter(self.config.routes.keys()))
359 fallback_step = self.config.routes[fallback_route]
361 ctx.goto(fallback_step)
363 if self.config.store_classification_result:
364 ctx.shared_state[f"{self.name}_classification"] = fallback_route
365 ctx.shared_state[f"{self.name}_next_step"] = fallback_step
366 ctx.shared_state[f"{self.name}_error"] = str(e)
368 logger.info(f"RouterAgent '{self.name}' using fallback route '{fallback_route}' -> '{fallback_step}'")
370 return ctx
373# Utility functions for creating common router configurations
374# 一般的なルーター設定を作成するためのユーティリティ関数
376def create_intent_router(
377 name: str = "intent_router",
378 intents: Dict[str, str] = None,
379 llm_pipeline: Optional[LLMPipeline] = None
380) -> RouterAgent:
381 """
382 Create a router for intent-based routing.
383 意図ベースのルーティング用のルーターを作成します。
385 Args:
386 name: Name of the router / ルーターの名前
387 intents: Mapping of intent names to step names / 意図名からステップ名へのマッピング
388 llm_pipeline: Optional LLM pipeline / オプションのLLMパイプライン
390 Returns:
391 RouterAgent: Configured intent router / 設定された意図ルーター
392 """
393 if intents is None:
394 intents = {
395 "question": "qa_step",
396 "request": "service_step",
397 "complaint": "support_step",
398 "other": "general_step"
399 }
401 config = RouterConfig(
402 name=name,
403 routes=intents,
404 classifier_type="llm",
405 classification_prompt="""
406Analyze the user input and classify it based on the user's intent.
408Intents:
409- question: User is asking for information or clarification
410- request: User is requesting a service or action
411- complaint: User is expressing dissatisfaction or reporting a problem
412- other: Input doesn't fit the above categories
414Consider the tone, content, and context of the input to determine the intent.
415""",
416 classification_examples={
417 "question": [
418 "How does this work?",
419 "What is the difference between X and Y?",
420 "Can you explain this feature?"
421 ],
422 "request": [
423 "Please update my account",
424 "I need to change my password",
425 "Can you help me set this up?"
426 ],
427 "complaint": [
428 "This is not working properly",
429 "I'm having issues with the service",
430 "This is frustrating and needs to be fixed"
431 ]
432 }
433 )
435 return RouterAgent(config, llm_pipeline)
438def create_content_type_router(
439 name: str = "content_router",
440 content_types: Dict[str, str] = None,
441 llm_pipeline: Optional[LLMPipeline] = None
442) -> RouterAgent:
443 """
444 Create a router for content type-based routing.
445 コンテンツタイプベースのルーティング用のルーターを作成します。
447 Args:
448 name: Name of the router / ルーターの名前
449 content_types: Mapping of content types to step names / コンテンツタイプからステップ名へのマッピング
450 llm_pipeline: Optional LLM pipeline / オプションのLLMパイプライン
452 Returns:
453 RouterAgent: Configured content type router / 設定されたコンテンツタイプルーター
454 """
455 if content_types is None:
456 content_types = {
457 "document": "document_processor",
458 "image": "image_processor",
459 "code": "code_processor",
460 "data": "data_processor"
461 }
463 config = RouterConfig(
464 name=name,
465 routes=content_types,
466 classifier_type="llm",
467 classification_prompt="""
468Analyze the input and classify it based on content type.
470Content types:
471- document: Text documents, articles, reports, letters
472- image: Images, photos, diagrams, charts
473- code: Programming code, scripts, configuration files
474- data: Structured data, databases, spreadsheets, JSON/XML
476Consider the format, structure, and content characteristics.
477""",
478 classification_examples={
479 "document": [
480 "This is a business report about quarterly results...",
481 "Dear Sir/Madam, I am writing to inform you...",
482 "Executive Summary: The following document outlines..."
483 ],
484 "code": [
485 "def hello_world():\n print('Hello, World!')",
486 "SELECT * FROM users WHERE age > 18;",
487 "{\n \"name\": \"config\",\n \"version\": \"1.0\"\n}"
488 ]
489 }
490 )
492 return RouterAgent(config, llm_pipeline)