Coverage for src/refinire/agents/router.py: 90%

133 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-15 18:51 +0900

1""" 

2RouterAgent implementation for routing inputs based on classification. 

3 

4The RouterAgent analyzes input data and routes it to appropriate processing paths 

5based on configurable routing logic and classification results. 

6""" 

7# RouterAgent implementation for routing inputs based on classification. 

8# RouterAgentは入力データを分析し、設定可能なルーティングロジックと分類結果に基づいて適切な処理パスにルーティングします。 

9 

10from typing import Any, Dict, List, Optional, Union, Callable, Literal 

11from pydantic import BaseModel, Field, field_validator 

12from abc import ABC, abstractmethod 

13import logging 

14 

15from .flow.step import Step 

16from .flow.context import Context 

17from .pipeline.llm_pipeline import LLMPipeline, create_simple_llm_pipeline 

18 

19logger = logging.getLogger(__name__) 

20 

21 

22class RouteClassifier(ABC): 

23 """ 

24 Abstract base class for route classification logic. 

25 ルート分類ロジックの抽象基底クラス。 

26 """ 

27 

28 @abstractmethod 

29 def classify(self, input_data: Any, context: Context) -> str: 

30 """ 

31 Classify input data and return the route key. 

32 入力データを分類してルートキーを返します。 

33  

34 Args: 

35 input_data: The input data to classify / 分類する入力データ 

36 context: The execution context / 実行コンテキスト 

37  

38 Returns: 

39 str: The route key for the classified input / 分類された入力のルートキー 

40 """ 

41 pass 

42 

43 

44class LLMClassifier(RouteClassifier): 

45 """ 

46 LLM-based classifier for route determination. 

47 ルート決定のためのLLMベース分類器。 

48 """ 

49 

50 def __init__( 

51 self, 

52 pipeline: LLMPipeline, 

53 classification_prompt: str, 

54 routes: List[str], 

55 examples: Optional[Dict[str, List[str]]] = None 

56 ): 

57 """ 

58 Initialize LLM classifier. 

59 LLM分類器を初期化します。 

60  

61 Args: 

62 pipeline: LLM pipeline for classification / 分類用のLLMパイプライン 

63 classification_prompt: Prompt template for classification / 分類用のプロンプトテンプレート 

64 routes: List of possible route keys / 可能なルートキーのリスト 

65 examples: Optional examples for each route / 各ルートのオプション例 

66 """ 

67 self.pipeline = pipeline 

68 self.classification_prompt = classification_prompt 

69 self.routes = routes 

70 self.examples = examples or {} 

71 

72 def classify(self, input_data: Any, context: Context) -> str: 

73 """ 

74 Classify input using LLM. 

75 LLMを使用して入力を分類します。 

76 """ 

77 # Build classification prompt with examples 

78 # 例を含む分類プロンプトを構築 

79 examples_text = "" 

80 if self.examples: 

81 examples_text = "\n\nExamples:\n" 

82 for route, example_list in self.examples.items(): 

83 examples_text += f"Route '{route}':\n" 

84 for example in example_list: 

85 examples_text += f"- {example}\n" 

86 

87 routes_text = ", ".join(self.routes) 

88 

89 full_prompt = f""" 

90{self.classification_prompt} 

91 

92Available routes: {routes_text} 

93{examples_text} 

94 

95Input to classify: {input_data} 

96 

97Respond with only the route key (one of: {routes_text}) 

98""" 

99 

100 try: 

101 result = self.pipeline.run(full_prompt, context) 

102 

103 # Clean and validate the result 

104 # 結果をクリーンアップして検証 

105 classified_route = result.strip().lower() 

106 

107 # Find matching route (case insensitive) 

108 # 一致するルートを検索(大文字小文字を区別しない) 

109 for route in self.routes: 

110 if route.lower() == classified_route: 

111 return route 

112 

113 # If no exact match, try partial matching 

114 # 完全一致しない場合、部分一致を試行 

115 for route in self.routes: 

116 if route.lower() in classified_route or classified_route in route.lower(): 

117 logger.warning(f"Partial route match: '{classified_route}' -> '{route}'") 

118 return route 

119 

120 # Return None to let RouterAgent handle fallback 

121 # RouterAgentにフォールバックを処理させるためNoneを返す 

122 logger.warning(f"Could not classify input, returning None for RouterAgent fallback") 

123 return None 

124 

125 except Exception as e: 

126 logger.error(f"Classification error: {e}") 

127 return None # Let RouterAgent handle fallback / RouterAgentにフォールバックを処理させる 

128 

129 

130class RuleBasedClassifier(RouteClassifier): 

131 """ 

132 Rule-based classifier using callable functions. 

133 呼び出し可能な関数を使用したルールベース分類器。 

134 """ 

135 

136 def __init__(self, rules: Dict[str, Callable[[Any, Context], bool]]): 

137 """ 

138 Initialize rule-based classifier. 

139 ルールベース分類器を初期化します。 

140  

141 Args: 

142 rules: Dictionary mapping route keys to classification functions 

143 ルートキーを分類関数にマッピングする辞書 

144 """ 

145 self.rules = rules 

146 

147 def classify(self, input_data: Any, context: Context) -> str: 

148 """ 

149 Classify input using rules. 

150 ルールを使用して入力を分類します。 

151 """ 

152 for route_key, rule_func in self.rules.items(): 

153 try: 

154 if rule_func(input_data, context): 

155 return route_key 

156 except Exception as e: 

157 logger.warning(f"Rule evaluation error for route '{route_key}': {e}") 

158 continue 

159 

160 # If no rules match, return the first route as fallback 

161 # ルールに一致しない場合、最初のルートをフォールバックとして返す 

162 fallback_route = next(iter(self.rules.keys())) 

163 logger.warning(f"No rules matched, using fallback route: {fallback_route}") 

164 return fallback_route 

165 

166 

167class RouterConfig(BaseModel): 

168 """ 

169 Configuration for RouterAgent. 

170 RouterAgentの設定。 

171 """ 

172 

173 name: str = Field(description="Name of the router agent / ルーターエージェントの名前") 

174 

175 routes: Dict[str, str] = Field( 

176 description="Mapping of route keys to next step names / ルートキーから次のステップ名へのマッピング" 

177 ) 

178 

179 classifier_type: Literal["llm", "rule"] = Field( 

180 default="llm", 

181 description="Type of classifier to use / 使用する分類器のタイプ" 

182 ) 

183 

184 # LLM classifier options 

185 classification_prompt: Optional[str] = Field( 

186 default=None, 

187 description="Prompt for LLM classification / LLM分類用のプロンプト" 

188 ) 

189 

190 classification_examples: Optional[Dict[str, List[str]]] = Field( 

191 default=None, 

192 description="Examples for each route / 各ルートの例" 

193 ) 

194 

195 # Rule-based classifier options 

196 classification_rules: Optional[Dict[str, Callable[[Any, Context], bool]]] = Field( 

197 default=None, 

198 description="Rules for classification / 分類のためのルール" 

199 ) 

200 

201 # Fallback options 

202 default_route: Optional[str] = Field( 

203 default=None, 

204 description="Default route if classification fails / 分類が失敗した場合のデフォルトルート" 

205 ) 

206 

207 store_classification_result: bool = Field( 

208 default=True, 

209 description="Whether to store classification result in context / 分類結果をコンテキストに保存するかどうか" 

210 ) 

211 

212 @field_validator("routes") 

213 @classmethod 

214 def routes_not_empty(cls, v): 

215 """Validate that routes are not empty / ルートが空でないことを検証""" 

216 if not v: 

217 raise ValueError("Routes cannot be empty / ルートは空にできません") 

218 return v 

219 

220 @field_validator("default_route") 

221 @classmethod 

222 def default_route_exists(cls, v, info): 

223 """Validate that default route exists in routes / デフォルトルートがルートに存在することを検証""" 

224 if v is not None and "routes" in info.data and v not in info.data["routes"]: 

225 raise ValueError(f"Default route '{v}' must exist in routes / デフォルトルート'{v}'はルートに存在する必要があります") 

226 return v 

227 

228 

229class RouterAgent(Step): 

230 """ 

231 Router agent that classifies input and routes to appropriate next steps. 

232 入力を分類して適切な次のステップにルーティングするルーターエージェント。 

233  

234 The RouterAgent analyzes input data using either LLM-based or rule-based 

235 classification and determines which processing path the input should follow. 

236 RouterAgentはLLMベースまたはルールベースの分類を使用して入力データを分析し、 

237 入力がどの処理パスに従うべきかを決定します。 

238 """ 

239 

240 def __init__(self, config: RouterConfig, llm_pipeline: Optional[LLMPipeline] = None): 

241 """ 

242 Initialize RouterAgent. 

243 RouterAgentを初期化します。 

244  

245 Args: 

246 config: Router configuration / ルーター設定 

247 llm_pipeline: Optional LLM pipeline for LLM-based classification 

248 LLMベース分類用のオプションのLLMパイプライン 

249 """ 

250 super().__init__(name=config.name) 

251 self.config = config 

252 

253 # Initialize classifier based on type 

254 # タイプに基づいて分類器を初期化 

255 if config.classifier_type == "llm": 

256 if llm_pipeline is None: 

257 # Create default LLM pipeline if none provided 

258 # 提供されていない場合はデフォルトのLLMパイプラインを作成 

259 llm_pipeline = create_simple_llm_pipeline() 

260 

261 # Use provided prompt or create default 

262 # 提供されたプロンプトを使用するか、デフォルトを作成 

263 prompt = config.classification_prompt or self._create_default_classification_prompt() 

264 

265 self.classifier = LLMClassifier( 

266 pipeline=llm_pipeline, 

267 classification_prompt=prompt, 

268 routes=list(config.routes.keys()), 

269 examples=config.classification_examples 

270 ) 

271 

272 elif config.classifier_type == "rule": 

273 if config.classification_rules is None: 

274 raise ValueError("classification_rules must be provided for rule-based classifier") 

275 

276 self.classifier = RuleBasedClassifier(config.classification_rules) 

277 

278 else: 

279 raise ValueError(f"Unsupported classifier type: {config.classifier_type}") 

280 

281 def _create_default_classification_prompt(self) -> str: 

282 """ 

283 Create default classification prompt. 

284 デフォルトの分類プロンプトを作成します。 

285 """ 

286 return f""" 

287Classify the given input into one of the available routes. 

288Consider the context and content of the input to determine the most appropriate route. 

289 

290Available routes: {', '.join(self.config.routes.keys())} 

291 

292Choose the route that best matches the input's intent, content, or characteristics. 

293""" 

294 

295 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

296 """ 

297 Execute the routing logic. 

298 ルーティングロジックを実行します。 

299  

300 Args: 

301 user_input: User input to classify and route / 分類・ルーティングするユーザー入力 

302 ctx: Execution context / 実行コンテキスト 

303  

304 Returns: 

305 Context: Updated context with routing information / ルーティング情報を含む更新されたコンテキスト 

306 """ 

307 # Update step info 

308 # ステップ情報を更新 

309 ctx.update_step_info(self.name) 

310 

311 try: 

312 # Use user_input as input_data for classification 

313 # 分類にuser_inputをinput_dataとして使用 

314 input_data = user_input or ctx.get_user_input() or "" 

315 

316 # Classify the input 

317 # 入力を分類 

318 route_key = self.classifier.classify(input_data, ctx) 

319 

320 # Check if classification failed (returned None) 

321 # 分類が失敗したかチェック(Noneが返された) 

322 classification_failed = route_key is None 

323 

324 # Validate route exists 

325 # ルートが存在することを検証 

326 if route_key is None or route_key not in self.config.routes: 

327 logger.warning(f"Invalid route key '{route_key}', using default") 

328 route_key = self.config.default_route or next(iter(self.config.routes.keys())) 

329 

330 # Store classification result in context if requested 

331 # 要求された場合は分類結果をコンテキストに保存 

332 if self.config.store_classification_result: 

333 ctx.shared_state[f"{self.name}_classification"] = route_key 

334 ctx.shared_state[f"{self.name}_next_step"] = self.config.routes[route_key] 

335 

336 # Store error info if classification failed 

337 # 分類が失敗した場合はエラー情報を保存 

338 if classification_failed: 

339 ctx.shared_state[f"{self.name}_error"] = "Classification failed" 

340 

341 # Set the next step for flow control 

342 # フロー制御用に次のステップを設定 

343 next_step_name = self.config.routes[route_key] 

344 ctx.goto(next_step_name) 

345 

346 logger.info(f"RouterAgent '{self.name}' classified input as '{route_key}' -> '{next_step_name}'") 

347 

348 return ctx 

349 

350 except Exception as e: 

351 logger.error(f"RouterAgent '{self.name}' error: {e}") 

352 

353 # Use fallback route 

354 # フォールバックルートを使用 

355 fallback_route = self.config.default_route or next(iter(self.config.routes.keys())) 

356 fallback_step = self.config.routes[fallback_route] 

357 

358 ctx.goto(fallback_step) 

359 

360 if self.config.store_classification_result: 

361 ctx.shared_state[f"{self.name}_classification"] = fallback_route 

362 ctx.shared_state[f"{self.name}_next_step"] = fallback_step 

363 ctx.shared_state[f"{self.name}_error"] = str(e) 

364 

365 logger.info(f"RouterAgent '{self.name}' using fallback route '{fallback_route}' -> '{fallback_step}'") 

366 

367 return ctx 

368 

369 

370# Utility functions for creating common router configurations 

371# 一般的なルーター設定を作成するためのユーティリティ関数 

372 

373def create_intent_router( 

374 name: str = "intent_router", 

375 intents: Dict[str, str] = None, 

376 llm_pipeline: Optional[LLMPipeline] = None 

377) -> RouterAgent: 

378 """ 

379 Create a router for intent-based routing. 

380 意図ベースのルーティング用のルーターを作成します。 

381  

382 Args: 

383 name: Name of the router / ルーターの名前 

384 intents: Mapping of intent names to step names / 意図名からステップ名へのマッピング 

385 llm_pipeline: Optional LLM pipeline / オプションのLLMパイプライン 

386  

387 Returns: 

388 RouterAgent: Configured intent router / 設定された意図ルーター 

389 """ 

390 if intents is None: 

391 intents = { 

392 "question": "qa_step", 

393 "request": "service_step", 

394 "complaint": "support_step", 

395 "other": "general_step" 

396 } 

397 

398 config = RouterConfig( 

399 name=name, 

400 routes=intents, 

401 classifier_type="llm", 

402 classification_prompt=""" 

403Analyze the user input and classify it based on the user's intent. 

404 

405Intents: 

406- question: User is asking for information or clarification 

407- request: User is requesting a service or action 

408- complaint: User is expressing dissatisfaction or reporting a problem 

409- other: Input doesn't fit the above categories 

410 

411Consider the tone, content, and context of the input to determine the intent. 

412""", 

413 classification_examples={ 

414 "question": [ 

415 "How does this work?", 

416 "What is the difference between X and Y?", 

417 "Can you explain this feature?" 

418 ], 

419 "request": [ 

420 "Please update my account", 

421 "I need to change my password", 

422 "Can you help me set this up?" 

423 ], 

424 "complaint": [ 

425 "This is not working properly", 

426 "I'm having issues with the service", 

427 "This is frustrating and needs to be fixed" 

428 ] 

429 } 

430 ) 

431 

432 return RouterAgent(config, llm_pipeline) 

433 

434 

435def create_content_type_router( 

436 name: str = "content_router", 

437 content_types: Dict[str, str] = None, 

438 llm_pipeline: Optional[LLMPipeline] = None 

439) -> RouterAgent: 

440 """ 

441 Create a router for content type-based routing. 

442 コンテンツタイプベースのルーティング用のルーターを作成します。 

443  

444 Args: 

445 name: Name of the router / ルーターの名前 

446 content_types: Mapping of content types to step names / コンテンツタイプからステップ名へのマッピング 

447 llm_pipeline: Optional LLM pipeline / オプションのLLMパイプライン 

448  

449 Returns: 

450 RouterAgent: Configured content type router / 設定されたコンテンツタイプルーター 

451 """ 

452 if content_types is None: 

453 content_types = { 

454 "document": "document_processor", 

455 "image": "image_processor", 

456 "code": "code_processor", 

457 "data": "data_processor" 

458 } 

459 

460 config = RouterConfig( 

461 name=name, 

462 routes=content_types, 

463 classifier_type="llm", 

464 classification_prompt=""" 

465Analyze the input and classify it based on content type. 

466 

467Content types: 

468- document: Text documents, articles, reports, letters 

469- image: Images, photos, diagrams, charts 

470- code: Programming code, scripts, configuration files 

471- data: Structured data, databases, spreadsheets, JSON/XML 

472 

473Consider the format, structure, and content characteristics. 

474""", 

475 classification_examples={ 

476 "document": [ 

477 "This is a business report about quarterly results...", 

478 "Dear Sir/Madam, I am writing to inform you...", 

479 "Executive Summary: The following document outlines..." 

480 ], 

481 "code": [ 

482 "def hello_world():\n print('Hello, World!')", 

483 "SELECT * FROM users WHERE age > 18;", 

484 "{\n \"name\": \"config\",\n \"version\": \"1.0\"\n}" 

485 ] 

486 } 

487 ) 

488 

489 return RouterAgent(config, llm_pipeline)