Coverage for src/refinire/agents/validator.py: 89%

215 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-15 18:51 +0900

1""" 

2ValidatorAgent implementation for data validation and business rule enforcement. 

3 

4ValidatorAgentはデータ検証とビジネスルール適用を行うエージェントです。 

5入力データの妥当性をチェックし、カスタム検証ルールを適用できます。 

6""" 

7 

8import logging 

9from abc import ABC, abstractmethod 

10from typing import Any, List, Optional, Dict, Union, Callable 

11from pydantic import BaseModel, Field, field_validator 

12from datetime import datetime 

13import re 

14 

15from .flow.context import Context 

16from .flow.step import Step 

17 

18logger = logging.getLogger(__name__) 

19 

20 

21class ValidationRule(ABC): 

22 """ 

23 Abstract base class for validation rules. 

24 検証ルールの抽象基底クラス。 

25 """ 

26 

27 def __init__(self, name: str, description: str = ""): 

28 """ 

29 Initialize validation rule. 

30 検証ルールを初期化します。 

31  

32 Args: 

33 name: Rule name / ルール名 

34 description: Rule description / ルールの説明 

35 """ 

36 self.name = name 

37 self.description = description 

38 

39 @abstractmethod 

40 def validate(self, data: Any, context: Context) -> bool: 

41 """ 

42 Validate data against this rule. 

43 このルールに対してデータを検証します。 

44  

45 Args: 

46 data: Data to validate / 検証するデータ 

47 context: Execution context / 実行コンテキスト 

48  

49 Returns: 

50 bool: True if valid, False otherwise / 有効な場合True、そうでなければFalse 

51 """ 

52 pass 

53 

54 @abstractmethod 

55 def get_error_message(self, data: Any) -> str: 

56 """ 

57 Get error message for validation failure. 

58 検証失敗時のエラーメッセージを取得します。 

59  

60 Args: 

61 data: Failed data / 失敗したデータ 

62  

63 Returns: 

64 str: Error message / エラーメッセージ 

65 """ 

66 pass 

67 

68 

69class RequiredRule(ValidationRule): 

70 """ 

71 Rule to check if data is not None or empty. 

72 データがNoneまたは空でないことをチェックするルール。 

73 """ 

74 

75 def __init__(self, name: str = "required"): 

76 super().__init__(name, "Data must not be None or empty") 

77 

78 def validate(self, data: Any, context: Context) -> bool: 

79 """Validate that data is not None or empty.""" 

80 if data is None: 

81 return False 

82 if isinstance(data, str) and data.strip() == "": 

83 return False 

84 if isinstance(data, (list, dict)) and len(data) == 0: 

85 return False 

86 return True 

87 

88 def get_error_message(self, data: Any) -> str: 

89 return f"Required field cannot be empty" 

90 

91 

92class EmailFormatRule(ValidationRule): 

93 """ 

94 Rule to validate email format. 

95 メール形式を検証するルール。 

96 """ 

97 

98 def __init__(self, name: str = "email_format"): 

99 super().__init__(name, "Data must be a valid email format") 

100 self.email_pattern = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$') 

101 

102 def validate(self, data: Any, context: Context) -> bool: 

103 """Validate email format.""" 

104 if not isinstance(data, str): 

105 return False 

106 return bool(self.email_pattern.match(data)) 

107 

108 def get_error_message(self, data: Any) -> str: 

109 return f"'{data}' is not a valid email format" 

110 

111 

112class LengthRule(ValidationRule): 

113 """ 

114 Rule to validate data length. 

115 データ長を検証するルール。 

116 """ 

117 

118 def __init__(self, min_length: Optional[int] = None, max_length: Optional[int] = None, name: str = "length"): 

119 super().__init__(name, f"Data length must be between {min_length} and {max_length}") 

120 self.min_length = min_length 

121 self.max_length = max_length 

122 

123 def validate(self, data: Any, context: Context) -> bool: 

124 """Validate data length.""" 

125 if data is None: 

126 return False 

127 

128 length = len(str(data)) 

129 

130 if self.min_length is not None and length < self.min_length: 

131 return False 

132 if self.max_length is not None and length > self.max_length: 

133 return False 

134 

135 return True 

136 

137 def get_error_message(self, data: Any) -> str: 

138 length = len(str(data)) if data is not None else 0 

139 return f"Length {length} is not between {self.min_length} and {self.max_length}" 

140 

141 

142class RangeRule(ValidationRule): 

143 """ 

144 Rule to validate numeric range. 

145 数値範囲を検証するルール。 

146 """ 

147 

148 def __init__(self, min_value: Optional[Union[int, float]] = None, 

149 max_value: Optional[Union[int, float]] = None, name: str = "range"): 

150 super().__init__(name, f"Value must be between {min_value} and {max_value}") 

151 self.min_value = min_value 

152 self.max_value = max_value 

153 

154 def validate(self, data: Any, context: Context) -> bool: 

155 """Validate numeric range.""" 

156 try: 

157 value = float(data) 

158 

159 if self.min_value is not None and value < self.min_value: 

160 return False 

161 if self.max_value is not None and value > self.max_value: 

162 return False 

163 

164 return True 

165 except (ValueError, TypeError): 

166 return False 

167 

168 def get_error_message(self, data: Any) -> str: 

169 return f"Value '{data}' is not between {self.min_value} and {self.max_value}" 

170 

171 

172class RegexRule(ValidationRule): 

173 """ 

174 Rule to validate data against a regular expression. 

175 正規表現に対してデータを検証するルール。 

176 """ 

177 

178 def __init__(self, pattern: str, name: str = "regex"): 

179 super().__init__(name, f"Data must match pattern: {pattern}") 

180 self.pattern = re.compile(pattern) 

181 

182 def validate(self, data: Any, context: Context) -> bool: 

183 """Validate data against regex pattern.""" 

184 if not isinstance(data, str): 

185 return False 

186 return bool(self.pattern.match(data)) 

187 

188 def get_error_message(self, data: Any) -> str: 

189 return f"'{data}' does not match required pattern" 

190 

191 

192class CustomFunctionRule(ValidationRule): 

193 """ 

194 Rule using a custom validation function. 

195 カスタム検証関数を使用するルール。 

196 """ 

197 

198 def __init__(self, validation_func: Callable[[Any, Context], bool], 

199 error_message: str, name: str = "custom"): 

200 super().__init__(name, "Custom validation rule") 

201 self.validation_func = validation_func 

202 self.error_message = error_message 

203 

204 def validate(self, data: Any, context: Context) -> bool: 

205 """Validate using custom function.""" 

206 try: 

207 return self.validation_func(data, context) 

208 except Exception as e: 

209 logger.warning(f"Custom validation function error: {e}") 

210 return False 

211 

212 def get_error_message(self, data: Any) -> str: 

213 return self.error_message 

214 

215 

216class ValidationResult: 

217 """ 

218 Result of validation operation. 

219 検証操作の結果。 

220 """ 

221 

222 def __init__(self, is_valid: bool, errors: List[str] = None, warnings: List[str] = None): 

223 """ 

224 Initialize validation result. 

225 検証結果を初期化します。 

226  

227 Args: 

228 is_valid: Whether validation passed / 検証が通ったかどうか 

229 errors: List of error messages / エラーメッセージのリスト 

230 warnings: List of warning messages / 警告メッセージのリスト 

231 """ 

232 self.is_valid = is_valid 

233 self.errors = errors or [] 

234 self.warnings = warnings or [] 

235 self.timestamp = datetime.now() 

236 

237 def add_error(self, error: str): 

238 """Add an error message.""" 

239 self.errors.append(error) 

240 self.is_valid = False 

241 

242 def add_warning(self, warning: str): 

243 """Add a warning message.""" 

244 self.warnings.append(warning) 

245 

246 def __str__(self) -> str: 

247 status = "VALID" if self.is_valid else "INVALID" 

248 return f"ValidationResult({status}, {len(self.errors)} errors, {len(self.warnings)} warnings)" 

249 

250 

251class ValidatorConfig(BaseModel): 

252 """ 

253 Configuration for ValidatorAgent. 

254 ValidatorAgentの設定。 

255 """ 

256 

257 name: str = Field(description="Name of the validator agent / バリデーターエージェントの名前") 

258 

259 rules: List[Dict[str, Any]] = Field( 

260 default=[], 

261 description="List of validation rules / 検証ルールのリスト" 

262 ) 

263 

264 fail_fast: bool = Field( 

265 default=False, 

266 description="Stop validation on first error / 最初のエラーで検証を停止" 

267 ) 

268 

269 store_result: bool = Field( 

270 default=True, 

271 description="Store validation result in context / 検証結果をコンテキストに保存" 

272 ) 

273 

274 raise_on_error: bool = Field( 

275 default=False, 

276 description="Raise exception on validation failure / 検証失敗時に例外を発生" 

277 ) 

278 

279 @field_validator("rules") 

280 @classmethod 

281 def rules_not_empty(cls, v): 

282 """Validate that rules are provided.""" 

283 if not v: 

284 logger.warning("No validation rules provided") 

285 return v 

286 

287 

288class ValidatorAgent(Step): 

289 """ 

290 Validator agent for data validation and business rule enforcement. 

291 データ検証とビジネスルール適用を行うバリデーターエージェント。 

292  

293 The ValidatorAgent checks input data against configured validation rules 

294 and returns validation results with detailed error messages. 

295 ValidatorAgentは設定された検証ルールに対して入力データをチェックし、 

296 詳細なエラーメッセージと共に検証結果を返します。 

297 """ 

298 

299 def __init__(self, config: ValidatorConfig, custom_rules: List[ValidationRule] = None): 

300 """ 

301 Initialize ValidatorAgent. 

302 ValidatorAgentを初期化します。 

303  

304 Args: 

305 config: Validator configuration / バリデーター設定 

306 custom_rules: Optional custom validation rules / オプションのカスタム検証ルール 

307 """ 

308 super().__init__(name=config.name) 

309 self.config = config 

310 self.validation_rules = self._build_validation_rules(custom_rules or []) 

311 

312 def _build_validation_rules(self, custom_rules: List[ValidationRule]) -> List[ValidationRule]: 

313 """ 

314 Build validation rules from configuration and custom rules. 

315 設定とカスタムルールから検証ルールを構築します。 

316 """ 

317 rules = list(custom_rules) 

318 

319 # Build rules from configuration 

320 # 設定からルールを構築 

321 for rule_config in self.config.rules: 

322 rule_type = rule_config.get("type") 

323 rule_name = rule_config.get("name", rule_type) 

324 

325 if rule_type == "required": 

326 rules.append(RequiredRule(rule_name)) 

327 

328 elif rule_type == "email": 

329 rules.append(EmailFormatRule(rule_name)) 

330 

331 elif rule_type == "length": 

332 min_len = rule_config.get("min_length") 

333 max_len = rule_config.get("max_length") 

334 rules.append(LengthRule(min_len, max_len, rule_name)) 

335 

336 elif rule_type == "range": 

337 min_val = rule_config.get("min_value") 

338 max_val = rule_config.get("max_value") 

339 rules.append(RangeRule(min_val, max_val, rule_name)) 

340 

341 elif rule_type == "regex": 

342 pattern = rule_config.get("pattern") 

343 if pattern: 

344 rules.append(RegexRule(pattern, rule_name)) 

345 

346 else: 

347 logger.warning(f"Unknown rule type: {rule_type}") 

348 

349 return rules 

350 

351 async def run(self, user_input: Optional[str], ctx: Context) -> Context: 

352 """ 

353 Execute the validation logic. 

354 検証ロジックを実行します。 

355  

356 Args: 

357 user_input: User input to validate / 検証するユーザー入力 

358 ctx: Execution context / 実行コンテキスト 

359  

360 Returns: 

361 Context: Updated context with validation results / 検証結果を含む更新されたコンテキスト 

362 """ 

363 # Update step info 

364 # ステップ情報を更新 

365 ctx.update_step_info(self.name) 

366 

367 try: 

368 # Determine data to validate 

369 # 検証するデータを決定 

370 data_to_validate = user_input 

371 if data_to_validate is None: 

372 data_to_validate = ctx.get_user_input() 

373 

374 # Perform validation 

375 # 検証を実行 

376 validation_result = self._validate_data(data_to_validate, ctx) 

377 

378 # Store result in context if requested 

379 # 要求された場合は結果をコンテキストに保存 

380 if self.config.store_result: 

381 ctx.shared_state[f"{self.name}_result"] = { 

382 "is_valid": validation_result.is_valid, 

383 "errors": validation_result.errors, 

384 "warnings": validation_result.warnings, 

385 "timestamp": validation_result.timestamp.isoformat() 

386 } 

387 

388 # Handle validation failure 

389 # 検証失敗を処理 

390 if not validation_result.is_valid: 

391 error_summary = f"Validation failed: {', '.join(validation_result.errors)}" 

392 

393 if self.config.raise_on_error: 

394 raise ValueError(error_summary) 

395 

396 logger.warning(f"ValidatorAgent '{self.name}': {error_summary}") 

397 ctx.shared_state[f"{self.name}_status"] = "failed" 

398 else: 

399 logger.info(f"ValidatorAgent '{self.name}': Validation successful") 

400 ctx.shared_state[f"{self.name}_status"] = "success" 

401 

402 # Add warnings to context if any 

403 # 警告があればコンテキストに追加 

404 if validation_result.warnings: 

405 ctx.shared_state[f"{self.name}_warnings"] = validation_result.warnings 

406 

407 return ctx 

408 

409 except Exception as e: 

410 logger.error(f"ValidatorAgent '{self.name}' error: {e}") 

411 

412 if self.config.store_result: 

413 ctx.shared_state[f"{self.name}_result"] = { 

414 "is_valid": False, 

415 "errors": [str(e)], 

416 "warnings": [], 

417 "timestamp": datetime.now().isoformat() 

418 } 

419 ctx.shared_state[f"{self.name}_status"] = "error" 

420 

421 if self.config.raise_on_error: 

422 raise 

423 

424 return ctx 

425 

426 def _validate_data(self, data: Any, context: Context) -> ValidationResult: 

427 """ 

428 Validate data against all configured rules. 

429 設定された全てのルールに対してデータを検証します。 

430 """ 

431 result = ValidationResult(is_valid=True) 

432 

433 for rule in self.validation_rules: 

434 try: 

435 is_valid = rule.validate(data, context) 

436 

437 if not is_valid: 

438 error_message = rule.get_error_message(data) 

439 result.add_error(f"[{rule.name}] {error_message}") 

440 

441 # Stop on first error if fail_fast is enabled 

442 # fail_fastが有効な場合は最初のエラーで停止 

443 if self.config.fail_fast: 

444 break 

445 

446 except Exception as e: 

447 error_message = f"Rule '{rule.name}' execution error: {e}" 

448 result.add_error(error_message) 

449 logger.warning(error_message) 

450 

451 if self.config.fail_fast: 

452 break 

453 

454 return result 

455 

456 def add_rule(self, rule: ValidationRule): 

457 """ 

458 Add a validation rule to the agent. 

459 エージェントに検証ルールを追加します。 

460 """ 

461 self.validation_rules.append(rule) 

462 

463 def get_rules(self) -> List[ValidationRule]: 

464 """ 

465 Get all validation rules. 

466 全ての検証ルールを取得します。 

467 """ 

468 return self.validation_rules.copy() 

469 

470 

471# Utility functions for creating common validators 

472# 一般的なバリデーターを作成するためのユーティリティ関数 

473 

474def create_email_validator(name: str = "email_validator") -> ValidatorAgent: 

475 """ 

476 Create a validator for email format. 

477 メール形式用のバリデーターを作成します。 

478 """ 

479 config = ValidatorConfig( 

480 name=name, 

481 rules=[ 

482 {"type": "required", "name": "email_required"}, 

483 {"type": "email", "name": "email_format"} 

484 ] 

485 ) 

486 return ValidatorAgent(config) 

487 

488 

489def create_required_validator(name: str = "required_validator") -> ValidatorAgent: 

490 """ 

491 Create a validator for required fields. 

492 必須フィールド用のバリデーターを作成します。 

493 """ 

494 config = ValidatorConfig( 

495 name=name, 

496 rules=[{"type": "required", "name": "required_check"}] 

497 ) 

498 return ValidatorAgent(config) 

499 

500 

501def create_length_validator(min_length: int = None, max_length: int = None, 

502 name: str = "length_validator") -> ValidatorAgent: 

503 """ 

504 Create a validator for length constraints. 

505 長さ制約用のバリデーターを作成します。 

506 """ 

507 config = ValidatorConfig( 

508 name=name, 

509 rules=[{ 

510 "type": "length", 

511 "name": "length_check", 

512 "min_length": min_length, 

513 "max_length": max_length 

514 }] 

515 ) 

516 return ValidatorAgent(config) 

517 

518 

519def create_custom_validator(validation_func: Callable[[Any, Context], bool], 

520 error_message: str, name: str = "custom_validator") -> ValidatorAgent: 

521 """ 

522 Create a validator with custom validation function. 

523 カスタム検証関数を持つバリデーターを作成します。 

524 """ 

525 config = ValidatorConfig(name=name) 

526 custom_rule = CustomFunctionRule(validation_func, error_message, "custom_rule") 

527 return ValidatorAgent(config, [custom_rule])