Coverage for src/refinire/agents/router.py: 90%

133 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 15:27 +0900

1""" 

2RouterAgent implementation for routing inputs based on classification. 

3 

4The RouterAgent analyzes input data and routes it to appropriate processing paths 

5based on configurable routing logic and classification results. 

6""" 

7# RouterAgent implementation for routing inputs based on classification. 

8# RouterAgentは入力データを分析し、設定可能なルーティングロジックと分類結果に基づいて適切な処理パスにルーティングします。 

9 

10from typing import Any, Dict, List, Optional, Union, Callable, Literal 

11from pydantic import BaseModel, Field, field_validator 

12from abc import ABC, abstractmethod 

13import logging 

14 

15from .flow.step import Step 

16from .flow.context import Context 

17from .pipeline.llm_pipeline import LLMPipeline, create_simple_llm_pipeline 

18 

19logger = logging.getLogger(__name__) 

20 

21 

22class RouteClassifier(ABC): 

23 """ 

24 Abstract base class for route classification logic. 

25 ルート分類ロジックの抽象基底クラス。 

26 """ 

27 

28 @abstractmethod 

29 def classify(self, input_data: Any, context: Context) -> str: 

30 """ 

31 Classify input data and return the route key. 

32 入力データを分類してルートキーを返します。 

33  

34 Args: 

35 input_data: The input data to classify / 分類する入力データ 

36 context: The execution context / 実行コンテキスト 

37  

38 Returns: 

39 str: The route key for the classified input / 分類された入力のルートキー 

40 """ 

41 pass 

42 

43 

44class LLMClassifier(RouteClassifier): 

45 """ 

46 LLM-based classifier for route determination. 

47 ルート決定のためのLLMベース分類器。 

48 """ 

49 

50 def __init__( 

51 self, 

52 pipeline: LLMPipeline, 

53 classification_prompt: str, 

54 routes: List[str], 

55 examples: Optional[Dict[str, List[str]]] = None 

56 ): 

57 """ 

58 Initialize LLM classifier. 

59 LLM分類器を初期化します。 

60  

61 Args: 

62 pipeline: LLM pipeline for classification / 分類用のLLMパイプライン 

63 classification_prompt: Prompt template for classification / 分類用のプロンプトテンプレート 

64 routes: List of possible route keys / 可能なルートキーのリスト 

65 examples: Optional examples for each route / 各ルートのオプション例 

66 """ 

67 self.pipeline = pipeline 

68 self.classification_prompt = classification_prompt 

69 self.routes = routes 

70 self.examples = examples or {} 

71 

72 def classify(self, input_data: Any, context: Context) -> str: 

73 """ 

74 Classify input using LLM. 

75 LLMを使用して入力を分類します。 

76 """ 

77 # Build classification prompt with examples 

78 # 例を含む分類プロンプトを構築 

79 examples_text = "" 

80 if self.examples: 

81 examples_text = "\n\nExamples:\n" 

82 for route, example_list in self.examples.items(): 

83 examples_text += f"Route '{route}':\n" 

84 for example in example_list: 

85 examples_text += f"- {example}\n" 

86 

87 routes_text = ", ".join(self.routes) 

88 

89 full_prompt = f""" 

90{self.classification_prompt} 

91 

92Available routes: {routes_text} 

93{examples_text} 

94 

95Input to classify: {input_data} 

96 

97Respond with only the route key (one of: {routes_text}) 

98""" 

99 

100 try: 

101 result = self.pipeline.run(full_prompt, context) 

102 

103 # Clean and validate the result 

104 # 結果をクリーンアップして検証 

105 classified_route = result.strip().lower() 

106 

107 # Find matching route (case insensitive) 

108 # 一致するルートを検索(大文字小文字を区別しない) 

109 for route in self.routes: 

110 if route.lower() == classified_route: 

111 return route 

112 

113 # If no exact match, try partial matching 

114 # 完全一致しない場合、部分一致を試行 

115 for route in self.routes: 

116 if route.lower() in classified_route or classified_route in route.lower(): 

117 logger.warning(f"Partial route match: '{classified_route}' -> '{route}'") 

118 return route 

119 

120 # Return None to let RouterAgent handle fallback 

121 # RouterAgentにフォールバックを処理させるためNoneを返す 

122 logger.warning(f"Could not classify input, returning None for RouterAgent fallback") 

123 return None 

124 

125 except Exception as e: 

126 logger.error(f"Classification error: {e}") 

127 return None # Let RouterAgent handle fallback / RouterAgentにフォールバックを処理させる 

128 

129 

130class RuleBasedClassifier(RouteClassifier): 

131 """ 

132 Rule-based classifier using callable functions. 

133 呼び出し可能な関数を使用したルールベース分類器。 

134 """ 

135 

136 def __init__(self, rules: Dict[str, Callable[[Any, Context], bool]]): 

137 """ 

138 Initialize rule-based classifier. 

139 ルールベース分類器を初期化します。 

140  

141 Args: 

142 rules: Dictionary mapping route keys to classification functions 

143 ルートキーを分類関数にマッピングする辞書 

144 """ 

145 self.rules = rules 

146 

147 def classify(self, input_data: Any, context: Context) -> str: 

148 """ 

149 Classify input using rules. 

150 ルールを使用して入力を分類します。 

151 """ 

152 for route_key, rule_func in self.rules.items(): 

153 try: 

154 if rule_func(input_data, context): 

155 return route_key 

156 except Exception as e: 

157 logger.warning(f"Rule evaluation error for route '{route_key}': {e}") 

158 continue 

159 

160 # If no rules match, return the first route as fallback 

161 # ルールに一致しない場合、最初のルートをフォールバックとして返す 

162 fallback_route = next(iter(self.rules.keys())) 

163 logger.warning(f"No rules matched, using fallback route: {fallback_route}") 

164 return fallback_route 

165 

166 

167class RouterConfig(BaseModel): 

168 """ 

169 Configuration for RouterAgent. 

170 RouterAgentの設定。 

171 """ 

172 

173 name: str = Field(description="Name of the router agent / ルーターエージェントの名前") 

174 

175 routes: Dict[str, str] = Field( 

176 description="Mapping of route keys to next step names / ルートキーから次のステップ名へのマッピング" 

177 ) 

178 

179 classifier_type: Literal["llm", "rule"] = Field( 

180 default="llm", 

181 description="Type of classifier to use / 使用する分類器のタイプ" 

182 ) 

183 

184 # LLM classifier options 

185 classification_prompt: Optional[str] = Field( 

186 default=None, 

187 description="Prompt for LLM classification / LLM分類用のプロンプト" 

188 ) 

189 

190 classification_examples: Optional[Dict[str, List[str]]] = Field( 

191 default=None, 

192 description="Examples for each route / 各ルートの例" 

193 ) 

194 

195 # Rule-based classifier options 

196 classification_rules: Optional[Dict[str, Callable[[Any, Context], bool]]] = Field( 

197 default=None, 

198 description="Rules for classification / 分類のためのルール" 

199 ) 

200 

201 # Fallback options 

202 default_route: Optional[str] = Field( 

203 default=None, 

204 description="Default route if classification fails / 分類が失敗した場合のデフォルトルート" 

205 ) 

206 

207 store_classification_result: bool = Field( 

208 default=True, 

209 description="Whether to store classification result in context / 分類結果をコンテキストに保存するかどうか" 

210 ) 

211 

212 @field_validator("routes") 

213 @classmethod 

214 def routes_not_empty(cls, v): 

215 """Validate that routes are not empty / ルートが空でないことを検証""" 

216 if not v: 

217 raise ValueError("Routes cannot be empty / ルートは空にできません") 

218 return v 

219 

220 @field_validator("default_route") 

221 @classmethod 

222 def default_route_exists(cls, v, info): 

223 """Validate that default route exists in routes / デフォルトルートがルートに存在することを検証""" 

224 if v is not None and "routes" in info.data and v not in info.data["routes"]: 

225 raise ValueError(f"Default route '{v}' must exist in routes / デフォルトルート'{v}'はルートに存在する必要があります") 

226 return v 

227 

228 

229class RouterAgent(Step): 

230 """ 

231 Router agent that classifies input and routes to appropriate next steps. 

232 入力を分類して適切な次のステップにルーティングするルーターエージェント。 

233  

234 The RouterAgent analyzes input data using either LLM-based or rule-based 

235 classification and determines which processing path the input should follow. 

236 RouterAgentはLLMベースまたはルールベースの分類を使用して入力データを分析し、 

237 入力がどの処理パスに従うべきかを決定します。 

238 """ 

239 

240 def __init__(self, config: RouterConfig, llm_pipeline: Optional[LLMPipeline] = None): 

241 """ 

242 Initialize RouterAgent. 

243 RouterAgentを初期化します。 

244  

245 Args: 

246 config: Router configuration / ルーター設定 

247 llm_pipeline: Optional LLM pipeline for LLM-based classification 

248 LLMベース分類用のオプションのLLMパイプライン 

249 """ 

250 super().__init__(name=config.name) 

251 self.config = config 

252 

253 # Initialize classifier based on type 

254 # タイプに基づいて分類器を初期化 

255 if config.classifier_type == "llm": 

256 if llm_pipeline is None: 

257 # Create default LLM pipeline if none provided 

258 # 提供されていない場合はデフォルトのLLMパイプラインを作成 

259 llm_pipeline = create_simple_llm_pipeline( 

260 name="router_classifier", 

261 instructions="You are a classification assistant. Classify the input text into the provided categories." 

262 ) 

263 

264 # Use provided prompt or create default 

265 # 提供されたプロンプトを使用するか、デフォルトを作成 

266 prompt = config.classification_prompt or self._create_default_classification_prompt() 

267 

268 self.classifier = LLMClassifier( 

269 pipeline=llm_pipeline, 

270 classification_prompt=prompt, 

271 routes=list(config.routes.keys()), 

272 examples=config.classification_examples 

273 ) 

274 

275 elif config.classifier_type == "rule": 

276 if config.classification_rules is None: 

277 raise ValueError("classification_rules must be provided for rule-based classifier") 

278 

279 self.classifier = RuleBasedClassifier(config.classification_rules) 

280 

281 else: 

282 raise ValueError(f"Unsupported classifier type: {config.classifier_type}") 

283 

284 def _create_default_classification_prompt(self) -> str: 

285 """ 

286 Create default classification prompt. 

287 デフォルトの分類プロンプトを作成します。 

288 """ 

289 return f""" 

290Classify the given input into one of the available routes. 

291Consider the context and content of the input to determine the most appropriate route. 

292 

293Available routes: {', '.join(self.config.routes.keys())} 

294 

295Choose the route that best matches the input's intent, content, or characteristics. 

296""" 

297 

298 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

299 """ 

300 Execute the routing logic. 

301 ルーティングロジックを実行します。 

302  

303 Args: 

304 user_input: User input to classify and route / 分類・ルーティングするユーザー入力 

305 ctx: Execution context / 実行コンテキスト 

306  

307 Returns: 

308 Context: Updated context with routing information / ルーティング情報を含む更新されたコンテキスト 

309 """ 

310 # Update step info 

311 # ステップ情報を更新 

312 ctx.update_step_info(self.name) 

313 

314 try: 

315 # Use user_input as input_data for classification 

316 # 分類にuser_inputをinput_dataとして使用 

317 input_data = user_input or ctx.get_user_input() or "" 

318 

319 # Classify the input 

320 # 入力を分類 

321 route_key = self.classifier.classify(input_data, ctx) 

322 

323 # Check if classification failed (returned None) 

324 # 分類が失敗したかチェック(Noneが返された) 

325 classification_failed = route_key is None 

326 

327 # Validate route exists 

328 # ルートが存在することを検証 

329 if route_key is None or route_key not in self.config.routes: 

330 logger.warning(f"Invalid route key '{route_key}', using default") 

331 route_key = self.config.default_route or next(iter(self.config.routes.keys())) 

332 

333 # Store classification result in context if requested 

334 # 要求された場合は分類結果をコンテキストに保存 

335 if self.config.store_classification_result: 

336 ctx.shared_state[f"{self.name}_classification"] = route_key 

337 ctx.shared_state[f"{self.name}_next_step"] = self.config.routes[route_key] 

338 

339 # Store error info if classification failed 

340 # 分類が失敗した場合はエラー情報を保存 

341 if classification_failed: 

342 ctx.shared_state[f"{self.name}_error"] = "Classification failed" 

343 

344 # Set the next step for flow control 

345 # フロー制御用に次のステップを設定 

346 next_step_name = self.config.routes[route_key] 

347 ctx.goto(next_step_name) 

348 

349 logger.info(f"RouterAgent '{self.name}' classified input as '{route_key}' -> '{next_step_name}'") 

350 

351 return ctx 

352 

353 except Exception as e: 

354 logger.error(f"RouterAgent '{self.name}' error: {e}") 

355 

356 # Use fallback route 

357 # フォールバックルートを使用 

358 fallback_route = self.config.default_route or next(iter(self.config.routes.keys())) 

359 fallback_step = self.config.routes[fallback_route] 

360 

361 ctx.goto(fallback_step) 

362 

363 if self.config.store_classification_result: 

364 ctx.shared_state[f"{self.name}_classification"] = fallback_route 

365 ctx.shared_state[f"{self.name}_next_step"] = fallback_step 

366 ctx.shared_state[f"{self.name}_error"] = str(e) 

367 

368 logger.info(f"RouterAgent '{self.name}' using fallback route '{fallback_route}' -> '{fallback_step}'") 

369 

370 return ctx 

371 

372 

373# Utility functions for creating common router configurations 

374# 一般的なルーター設定を作成するためのユーティリティ関数 

375 

376def create_intent_router( 

377 name: str = "intent_router", 

378 intents: Dict[str, str] = None, 

379 llm_pipeline: Optional[LLMPipeline] = None 

380) -> RouterAgent: 

381 """ 

382 Create a router for intent-based routing. 

383 意図ベースのルーティング用のルーターを作成します。 

384  

385 Args: 

386 name: Name of the router / ルーターの名前 

387 intents: Mapping of intent names to step names / 意図名からステップ名へのマッピング 

388 llm_pipeline: Optional LLM pipeline / オプションのLLMパイプライン 

389  

390 Returns: 

391 RouterAgent: Configured intent router / 設定された意図ルーター 

392 """ 

393 if intents is None: 

394 intents = { 

395 "question": "qa_step", 

396 "request": "service_step", 

397 "complaint": "support_step", 

398 "other": "general_step" 

399 } 

400 

401 config = RouterConfig( 

402 name=name, 

403 routes=intents, 

404 classifier_type="llm", 

405 classification_prompt=""" 

406Analyze the user input and classify it based on the user's intent. 

407 

408Intents: 

409- question: User is asking for information or clarification 

410- request: User is requesting a service or action 

411- complaint: User is expressing dissatisfaction or reporting a problem 

412- other: Input doesn't fit the above categories 

413 

414Consider the tone, content, and context of the input to determine the intent. 

415""", 

416 classification_examples={ 

417 "question": [ 

418 "How does this work?", 

419 "What is the difference between X and Y?", 

420 "Can you explain this feature?" 

421 ], 

422 "request": [ 

423 "Please update my account", 

424 "I need to change my password", 

425 "Can you help me set this up?" 

426 ], 

427 "complaint": [ 

428 "This is not working properly", 

429 "I'm having issues with the service", 

430 "This is frustrating and needs to be fixed" 

431 ] 

432 } 

433 ) 

434 

435 return RouterAgent(config, llm_pipeline) 

436 

437 

438def create_content_type_router( 

439 name: str = "content_router", 

440 content_types: Dict[str, str] = None, 

441 llm_pipeline: Optional[LLMPipeline] = None 

442) -> RouterAgent: 

443 """ 

444 Create a router for content type-based routing. 

445 コンテンツタイプベースのルーティング用のルーターを作成します。 

446  

447 Args: 

448 name: Name of the router / ルーターの名前 

449 content_types: Mapping of content types to step names / コンテンツタイプからステップ名へのマッピング 

450 llm_pipeline: Optional LLM pipeline / オプションのLLMパイプライン 

451  

452 Returns: 

453 RouterAgent: Configured content type router / 設定されたコンテンツタイプルーター 

454 """ 

455 if content_types is None: 

456 content_types = { 

457 "document": "document_processor", 

458 "image": "image_processor", 

459 "code": "code_processor", 

460 "data": "data_processor" 

461 } 

462 

463 config = RouterConfig( 

464 name=name, 

465 routes=content_types, 

466 classifier_type="llm", 

467 classification_prompt=""" 

468Analyze the input and classify it based on content type. 

469 

470Content types: 

471- document: Text documents, articles, reports, letters 

472- image: Images, photos, diagrams, charts 

473- code: Programming code, scripts, configuration files 

474- data: Structured data, databases, spreadsheets, JSON/XML 

475 

476Consider the format, structure, and content characteristics. 

477""", 

478 classification_examples={ 

479 "document": [ 

480 "This is a business report about quarterly results...", 

481 "Dear Sir/Madam, I am writing to inform you...", 

482 "Executive Summary: The following document outlines..." 

483 ], 

484 "code": [ 

485 "def hello_world():\n print('Hello, World!')", 

486 "SELECT * FROM users WHERE age > 18;", 

487 "{\n \"name\": \"config\",\n \"version\": \"1.0\"\n}" 

488 ] 

489 } 

490 ) 

491 

492 return RouterAgent(config, llm_pipeline)