interactor
64class Interactor: 65 def __init__( 66 self, 67 base_url: Optional[str] = None, 68 api_key: Optional[str] = None, 69 model: str = "openai:gpt-4o-mini", 70 fallback_model = "ollama:mistral-nemo:latest", 71 tools: Optional[bool] = True, 72 stream: bool = True, 73 quiet: bool = False, 74 context_length: int = 128000, 75 max_retries: int = 3, 76 retry_delay: float = 1.0, 77 log_path: Optional[str] = None, 78 log_level: Optional[str] = "ERROR", 79 raw: Optional[bool] = False, 80 session_enabled: bool = False, 81 session_id: Optional[str] = None, 82 session_path: Optional[str] = None 83 ): 84 """Initialize the universal AI interaction client. 85 86 Args: 87 base_url: Optional base URL for the API. If None, uses the provider's default URL. 88 api_key: Optional API key. If None, attempts to use environment variables based on provider. 89 model: Model identifier in format "provider:model_name" (e.g., "openai:gpt-4o-mini"). 90 tools: Enable (True) or disable (False) tool calling; None for auto-detection based on model support. 91 stream: Enable (True) or disable (False) streaming responses. 92 context_length: Maximum number of tokens to maintain in conversation history. 93 max_retries: Maximum number of retries for failed API calls. 94 retry_delay: Initial delay (in seconds) for exponential backoff retries. 95 session_enabled: Enable persistent session support. 96 session_id: Optional session ID to load messages from. 97 98 Raises: 99 ValueError: If provider is not supported or API key is missing for non-Ollama providers. 100 """ 101 self.system = "You are a helpful Assistant." 102 self.raw = raw 103 self.quiet = quiet 104 self.models = [] 105 106 # Configure Logging 107 level_map = { 108 "DEBUG": logging.DEBUG, 109 "INFO": logging.INFO, 110 "WARNING": logging.WARNING, 111 "ERROR": logging.ERROR, 112 "CRITICAL": logging.CRITICAL 113 } 114 115 self.logger = logging.getLogger(f"InteractorLogger_{id(self)}") 116 self.logger.setLevel(level_map[log_level]) 117 118 self.providers = { 119 "openai": { 120 "sdk": "openai", 121 "base_url": "https://api.openai.com/v1", 122 "api_key": api_key or os.getenv("OPENAI_API_KEY") or None 123 }, 124 "ollama": { 125 "sdk": "openai", 126 "base_url": "http://localhost:11434/v1", 127 "api_key": api_key or "ollama" 128 }, 129 "nvidia": { 130 "sdk": "openai", 131 "base_url": "https://integrate.api.nvidia.com/v1", 132 "api_key": api_key or os.getenv("NVIDIA_API_KEY") or None 133 }, 134 "google": { 135 "sdk": "openai", 136 "base_url": "https://generativelanguage.googleapis.com/v1beta/openai", 137 "api_key": api_key or os.getenv("GEMINI_API_KEY") or None 138 }, 139 "anthropic": { 140 "sdk": "anthropic", 141 "base_url": "https://api.anthropic.com/v1", 142 "api_key": api_key or os.getenv("ANTHROPIC_API_KEY") or None 143 }, 144 "mistral": { 145 "sdk": "openai", 146 "base_url": "https://api.mistral.ai/v1", 147 "api_key": api_key or os.getenv("MISTRAL_API_KEY") or None 148 }, 149 "deepseek": { 150 "sdk": "openai", 151 "base_url": "https://api.deepseek.com", 152 "api_key": api_key or os.getenv("DEEPSEEK_API_KEY") or None 153 }, 154 } 155 """ 156 "grok": { 157 "sdk": "grok", 158 "base_url": "https://api.x.ai/v1", 159 "api_key": api_key or os.getenv("GROK_API_KEY") or None 160 } 161 } 162 """ 163 164 # Console log handler (always enabled at WARNING+) 165 if not self.logger.handlers: 166 console_handler = logging.StreamHandler(sys.stdout) 167 console_handler.setLevel(logging.WARNING) 168 console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) 169 self.logger.addHandler(console_handler) 170 171 self._log_enabled = False 172 if log_path: 173 file_handler = logging.FileHandler(log_path) 174 file_handler.setLevel(logging.DEBUG) 175 file_handler.setFormatter(logging.Formatter( 176 "%(asctime)s - %(levelname)s - %(message)s", 177 datefmt="%Y-%m-%d %H:%M:%S" 178 )) 179 self.logger.addHandler(file_handler) 180 self._log_enabled = True 181 182 183 self.token_estimate = 0 184 self.last_token_estimate = 0 185 self.stream = stream 186 self.tools = [] 187 self.session_history = [] 188 self.history = [] 189 self.context_length = context_length 190 self.encoding = None 191 self.max_retries = max_retries 192 self.retry_delay = retry_delay 193 self.reveal_tool = [] 194 self.fallback_model = fallback_model 195 self.sdk = None 196 197 # Session support 198 self.session_enabled = session_enabled 199 self.session_id = session_id 200 self._last_session_id = session_id 201 self.session = Session(directory=session_path) if session_enabled else None 202 203 204 if model is None: 205 model = "openai:gpt-4o-mini" 206 207 # Initialize model + encoding 208 self._setup_client(model, base_url, api_key) 209 self.tools_enabled = self.tools_supported if tools is None else tools and self.tools_supported 210 self._setup_encoding() 211 self.messages_add(role="system", content=self.system) 212 213 214 215 def _log(self, message: str, level: str = "info"): 216 """Log a message to the configured logging handlers. 217 218 This internal method handles logging to both console and file handlers 219 if configured. It respects the logging level and only logs if logging 220 is enabled. 221 222 Args: 223 message (str): The message to log 224 level (str): Logging level - one of "debug", "info", "warning", "error" 225 """ 226 if self._log_enabled: 227 getattr(self.logger, level)(message) 228 229 230 def _setup_client( 231 self, 232 model: Optional[str] = None, 233 base_url: Optional[str] = None, 234 api_key: Optional[str] = None 235 ): 236 """Initialize or reconfigure the Interactor for the given model and SDK. 237 238 Ensures idempotent setup, assigns SDK-specific clients and tool handling logic, 239 and normalizes history to match the provider-specific message schema. 240 """ 241 if not model: 242 raise ValueError("Model must be specified as 'provider:model_name'") 243 244 provider, model_name = model.split(":", 1) 245 246 if not hasattr(self, "session_history"): 247 self.session_history = [] 248 249 # Skip setup if nothing has changed (client may not yet exist on first call) 250 if ( 251 hasattr(self, "client") 252 and self.client 253 and self.provider == provider 254 and self.model == model_name 255 and self.base_url == (base_url or self.base_url) 256 ): 257 return 258 259 if provider not in self.providers: 260 raise ValueError(f"Unsupported provider: {provider}. Supported: {list(self.providers.keys())}") 261 262 # Load provider configuration 263 provider_config = self.providers[provider] 264 self.sdk = provider_config.get("sdk", "openai") 265 self.provider = provider 266 self.model = model_name 267 self.base_url = base_url or provider_config["base_url"] 268 effective_api_key = api_key or provider_config["api_key"] 269 270 if not effective_api_key and provider != "ollama": 271 raise ValueError(f"API key not provided and not found in environment for {provider.upper()}_API_KEY") 272 273 # SDK-specific configuration 274 if self.sdk == "openai": 275 self.client = openai.OpenAI(base_url=self.base_url, api_key=effective_api_key) 276 self.async_client = openai.AsyncOpenAI(base_url=self.base_url, api_key=effective_api_key) 277 self.sdk_runner = self._openai_runner 278 self.tool_key = "tool_call_id" 279 280 elif self.sdk == "anthropic": 281 self.client = anthropic.Anthropic(api_key=effective_api_key) 282 self.async_client = anthropic.AsyncAnthropic(api_key=effective_api_key) 283 self.sdk_runner = self._anthropic_runner 284 self.tool_key = "tool_use_id" 285 286 else: 287 raise ValueError(f"Unsupported SDK type: {self.sdk}") 288 289 # Determine tool support 290 self.tools_supported = self._check_tool_support() 291 if not self.tools_supported: 292 self.logger.warning(f"Tool calling not supported for {provider}:{model_name}") 293 294 # Normalize session history to match SDK after any provider/model change 295 self._normalizer(force=True) 296 297 self._log(f"[MODEL] Switched to {provider}:{model_name}") 298 299 300 def _check_tool_support(self) -> bool: 301 """Determine if the current model supports tool calling. 302 303 Returns: 304 bool: True if tools are supported for the active provider/model, False otherwise. 305 """ 306 try: 307 if self.sdk == "openai": 308 response = self.client.chat.completions.create( 309 model=self.model, 310 messages=[{"role": "user", "content": "Test tool support."}], 311 stream=False, 312 tools=[{ 313 "type": "function", 314 "function": { 315 "name": "test_tool", 316 "description": "Check tool support", 317 "parameters": { 318 "type": "object", 319 "properties": { 320 "query": {"type": "string"} 321 }, 322 "required": ["query"] 323 } 324 } 325 }], 326 tool_choice="auto" 327 ) 328 message = response.choices[0].message 329 return bool(message.tool_calls and len(message.tool_calls) > 0) 330 331 elif self.sdk == "anthropic": 332 # For Claude models, we pre-define support based on model ID 333 # Known tool-supporting Claude models 334 claude_tool_models = ["claude-3-opus", "claude-3-sonnet", "claude-3-haiku", 335 "claude-3.5-sonnet", "claude-3.7-sonnet"] 336 337 # Check if the current model supports tools 338 for supported_model in claude_tool_models: 339 if supported_model in self.model.lower(): 340 self._log(f"[TOOLS] Anthropic model {self.model} is known to support tools") 341 return True 342 343 # If not explicitly supported, try to test 344 try: 345 _ = self.client.messages.create( 346 model=self.model, 347 messages=[{"role": "user", "content": "What's the weather?"}], 348 tools=[{ 349 "name": "test_tool", 350 "description": "Check tool support", 351 "input_schema": { 352 "type": "object", 353 "properties": { 354 "query": {"type": "string"} 355 }, 356 "required": ["query"] 357 } 358 }], 359 max_tokens=10 360 ) 361 return True 362 except anthropic.BadRequestError as e: 363 error_msg = str(e).lower() 364 if "tool" in error_msg and "not supported" in error_msg: 365 self._log(f"[TOOLS] Anthropic model {self.model} does not support tools: {e}") 366 return False 367 if "not a supported tool field" in error_msg: 368 self._log(f"[TOOLS] Anthropic API rejected tool format: {e}") 369 return False 370 raise 371 except Exception as e: 372 self._log(f"[TOOLS] Unexpected error testing tool support: {e}", level="error") 373 return False 374 375 else: 376 self.logger.warning(f"Tool support check not implemented for SDK '{self.sdk}'") 377 return False 378 379 except Exception as e: 380 self.logger.error(f"Tool support check failed for {self.provider}:{self.model} — {e}") 381 return False 382 383 384 def add_function( 385 self, 386 external_callable: Callable, 387 name: Optional[str] = None, 388 description: Optional[str] = None, 389 override: bool = False, 390 disabled: bool = False, 391 schema_extensions: Optional[Dict[str, Any]] = None 392 ): 393 """ 394 Register a function for LLM tool calling with full type hints and metadata. 395 396 Args: 397 external_callable (Callable): The function to register. 398 name (Optional[str]): Optional custom name. Defaults to function's __name__. 399 description (Optional[str]): Optional custom description. Defaults to first line of docstring. 400 override (bool): If True, replaces an existing tool with the same name. 401 disabled (bool): If True, registers the function in a disabled state. 402 schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to 403 schema extensions that override or add to the auto-generated schema. 404 405 Raises: 406 ValueError: If the callable is invalid or duplicate name found without override. 407 408 Example: 409 interactor.add_function( 410 my_tool, 411 override=True, 412 disabled=False, 413 schema_extensions={ 414 "param1": {"minimum": 0, "maximum": 100}, 415 "param2": {"format": "email"} 416 } 417 ) 418 """ 419 def _python_type_to_schema(ptype: Any) -> dict: 420 """Convert a Python type annotation to OpenAI-compatible JSON Schema.""" 421 # Handle None case 422 if ptype is None: 423 return {"type": "null"} 424 425 # Get the origin and arguments of the type 426 origin = get_origin(ptype) 427 args = get_args(ptype) 428 429 # Handle Union types (including Optional) 430 if origin is Union: 431 # Check for Optional (Union with None) 432 none_type = type(None) 433 if none_type in args: 434 non_none = [a for a in args if a is not none_type] 435 if len(non_none) == 1: 436 inner = _python_type_to_schema(non_none[0]) 437 inner_copy = inner.copy() 438 inner_copy["nullable"] = True 439 return inner_copy 440 # Multiple types excluding None 441 types = [_python_type_to_schema(a) for a in non_none] 442 return {"anyOf": types, "nullable": True} 443 # Regular Union without None 444 return {"anyOf": [_python_type_to_schema(a) for a in args]} 445 446 # Handle List and similar container types 447 if origin in (list, List): 448 item_type = args[0] if args else Any 449 if item_type is Any: 450 return {"type": "array"} 451 return {"type": "array", "items": _python_type_to_schema(item_type)} 452 453 # Handle Dict types with typing info 454 if origin in (dict, Dict): 455 if not args or len(args) != 2: 456 return {"type": "object"} 457 458 key_type, val_type = args 459 # We can only really use val_type in JSON Schema 460 if val_type is not Any and val_type is not object: 461 return { 462 "type": "object", 463 "additionalProperties": _python_type_to_schema(val_type) 464 } 465 return {"type": "object"} 466 467 # Handle Literal types for enums 468 if origin is Literal: 469 values = args 470 # Try to determine type from values 471 if all(isinstance(v, str) for v in values): 472 return {"type": "string", "enum": list(values)} 473 elif all(isinstance(v, bool) for v in values): 474 return {"type": "boolean", "enum": list(values)} 475 elif all(isinstance(v, (int, float)) for v in values): 476 return {"type": "number", "enum": list(values)} 477 else: 478 # Mixed types, use anyOf 479 return {"anyOf": [{"type": _get_json_type(v), "enum": [v]} for v in values]} 480 481 # Handle basic types 482 if ptype is str: 483 return {"type": "string"} 484 if ptype is int: 485 return {"type": "integer"} 486 if ptype is float: 487 return {"type": "number"} 488 if ptype is bool: 489 return {"type": "boolean"} 490 491 # Handle common datetime types 492 if ptype is datetime: 493 return {"type": "string", "format": "date-time"} 494 if ptype is date: 495 return {"type": "string", "format": "date"} 496 497 # Handle UUID 498 if ptype is uuid.UUID: 499 return {"type": "string", "format": "uuid"} 500 501 # Default to object for any other types 502 return {"type": "object"} 503 504 def _get_json_type(value): 505 """Get the JSON Schema type name for a Python value. 506 507 This helper function maps Python types to their corresponding 508 JSON Schema type names. It handles basic types and provides 509 sensible defaults for complex types. 510 511 Args: 512 value: The Python value to get the JSON type for 513 514 Returns: 515 str: The JSON Schema type name ('string', 'number', 'boolean', 516 'array', 'object', or 'object' as default) 517 """ 518 if isinstance(value, str): 519 return "string" 520 elif isinstance(value, bool): 521 return "boolean" 522 elif isinstance(value, int) or isinstance(value, float): 523 return "number" 524 elif isinstance(value, list): 525 return "array" 526 elif isinstance(value, dict): 527 return "object" 528 else: 529 return "object" # Default 530 531 def _parse_param_docs(docstring: str) -> dict: 532 """Extract parameter descriptions from a docstring.""" 533 if not docstring: 534 return {} 535 536 lines = docstring.splitlines() 537 param_docs = {} 538 current_param = None 539 in_params = False 540 541 # Regular expressions for finding parameter sections and param lines 542 param_section_re = re.compile(r"^(Args|Parameters):\s*$") 543 param_line_re = re.compile(r"^\s{4}(\w+)\s*(?:\([^\)]*\))?:\s*(.*)") 544 545 for line in lines: 546 # Check if we're entering the parameters section 547 if param_section_re.match(line.strip()): 548 in_params = True 549 continue 550 551 if in_params: 552 # Skip empty lines 553 if not line.strip(): 554 continue 555 556 # Check for a parameter definition line 557 match = param_line_re.match(line) 558 if match: 559 current_param = match.group(1) 560 param_docs[current_param] = match.group(2).strip() 561 # Check for continuation of a parameter description 562 elif current_param and line.startswith(" " * 8): 563 param_docs[current_param] += " " + line.strip() 564 # If we see a line that doesn't match our patterns, we're out of the params section 565 else: 566 current_param = None 567 568 return param_docs 569 570 # Start of main function logic 571 572 # Skip if tools are disabled 573 if not self.tools_enabled: 574 return 575 576 # Validate input callable 577 if not external_callable: 578 raise ValueError("A valid external callable must be provided.") 579 580 # Set function name, either from parameter or from callable's __name__ 581 function_name = name or external_callable.__name__ 582 583 # Try to get docstring and extract description 584 try: 585 docstring = inspect.getdoc(external_callable) 586 description = description or (docstring.split("\n")[0].strip() if docstring else "No description provided.") 587 except Exception as e: 588 self._log(f"[TOOL] Warning: Could not extract docstring from {function_name}: {e}", level="warning") 589 docstring = "" 590 description = description or "No description provided." 591 592 # Extract parameter documentation from docstring 593 param_docs = _parse_param_docs(docstring) 594 595 # Handle conflicts with existing functions 596 if override: 597 self.delete_function(function_name) 598 elif any(t["function"]["name"] == function_name for t in self.tools): 599 raise ValueError(f"Function '{function_name}' is already registered. Use override=True to replace.") 600 601 # Try to get function signature for parameter info 602 try: 603 signature = inspect.signature(external_callable) 604 except (ValueError, TypeError) as e: 605 raise ValueError(f"Cannot inspect callable '{function_name}': {e}") 606 607 # Process parameters to build schema 608 properties = {} 609 required = [] 610 611 for param_name, param in signature.parameters.items(): 612 # Skip self, cls parameters for instance/class methods 613 if param_name in ("self", "cls") and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: 614 continue 615 616 # Get parameter annotation, defaulting to Any 617 annotation = param.annotation if param.annotation != inspect.Parameter.empty else Any 618 619 try: 620 # Convert Python type to JSON Schema 621 schema = _python_type_to_schema(annotation) 622 623 # Add description from docstring or create a default one 624 schema["description"] = param_docs.get(param_name, f"{param_name} parameter") 625 626 # Add to properties 627 properties[param_name] = schema 628 629 # If no default value is provided, parameter is required 630 if param.default == inspect.Parameter.empty: 631 required.append(param_name) 632 self._log(f"[TOOL] Parameter '{param_name}' is required", level="debug") 633 else: 634 self._log(f"[TOOL] Parameter '{param_name}' has default value: {param.default}", level="debug") 635 636 except Exception as e: 637 self._log(f"[TOOL] Error processing parameter {param_name} for {function_name}: {e}", level="error") 638 # Add a basic object schema as fallback 639 properties[param_name] = { 640 "type": "string", # Default to string instead of object for better compatibility 641 "description": f"{param_name} parameter (type conversion failed)" 642 } 643 644 # For parameters with no default value, mark as required even if processing failed 645 if param.default == inspect.Parameter.empty: 646 required.append(param_name) 647 self._log(f"[TOOL] Parameter '{param_name}' marked as required despite conversion failure", level="debug") 648 649 # Apply schema extensions if provided 650 if schema_extensions: 651 for param_name, extensions in schema_extensions.items(): 652 if param_name in properties: 653 properties[param_name].update(extensions) 654 655 # Create parameters object with proper placement of 'required' field 656 parameters = { 657 "type": "object", 658 "properties": properties, 659 } 660 661 # Only add required field if there are required parameters 662 if required: 663 parameters["required"] = required 664 665 # Build the final tool specification 666 tool_spec = { 667 "type": "function", 668 "function": { 669 "name": function_name, 670 "description": description, 671 "parameters": parameters 672 } 673 } 674 675 # Set disabled flag if requested 676 if disabled: 677 tool_spec["function"]["disabled"] = True 678 679 # Add to tools list 680 self.tools.append(tool_spec) 681 682 # Make the function available as an attribute on the instance 683 setattr(self, function_name, external_callable) 684 685 # Log the registration with detailed information 686 self._log(f"[TOOL] Registered function '{function_name}' with {len(properties)} parameters", level="info") 687 if required: 688 self._log(f"[TOOL] Required parameters: {required}", level="info") 689 690 return function_name # Return the name for reference 691 692 693 def disable_function(self, name: str) -> bool: 694 """ 695 Disable a registered tool function by name. 696 697 This marks the function as inactive for tool calling without removing it from the internal registry. 698 The function remains visible in the tool listing but is skipped during tool selection by the LLM. 699 700 Args: 701 name (str): The name of the function to disable. 702 703 Returns: 704 bool: True if the function was found and disabled, False otherwise. 705 706 Example: 707 interactor.disable_function("extract_text") 708 """ 709 for tool in self.tools: 710 if tool["function"]["name"] == name: 711 tool["function"]["disabled"] = True 712 return True 713 return False 714 715 716 def enable_function(self, name: str) -> bool: 717 """ 718 Re-enable a previously disabled tool function by name. 719 720 This removes the 'disabled' flag from a tool function, making it available again for LLM use. 721 722 Args: 723 name (str): The name of the function to enable. 724 725 Returns: 726 bool: True if the function was found and enabled, False otherwise. 727 728 Example: 729 interactor.enable_function("extract_text") 730 """ 731 for tool in self.tools: 732 if tool["function"]["name"] == name: 733 tool["function"].pop("disabled", None) 734 return True 735 return False 736 737 738 def delete_function(self, name: str) -> bool: 739 """ 740 Permanently remove a registered tool function from the Interactor. 741 742 This deletes both the tool metadata and the callable attribute, making it fully inaccessible 743 from the active session. Useful for dynamically trimming the toolset. 744 745 Args: 746 name (str): The name of the function to delete. 747 748 Returns: 749 bool: True if the function was found and removed, False otherwise. 750 751 Example: 752 interactor.delete_function("extract_text") 753 """ 754 before = len(self.tools) 755 self.tools = [tool for tool in self.tools if tool["function"]["name"] != name] 756 if hasattr(self, name): 757 delattr(self, name) 758 return len(self.tools) < before 759 760 761 def list_functions(self) -> List[Dict[str, Any]]: 762 """Get the list of registered functions for tool calling. 763 764 Returns: 765 List[Dict[str, Any]]: List of registered functions. 766 """ 767 return self.tools 768 769 770 def list_models(self, update=False): 771 """list available models""" 772 if update: 773 self._update_available_models() 774 775 return self.models 776 777 def _update_available_models( 778 self, 779 providers: Optional[Union[str, List[str]]] = None, 780 filter: Optional[str] = None 781 ) -> List[str]: 782 """Retrieve available models from configured providers. 783 784 Args: 785 providers: Provider name or list of provider names. If None, all are queried. 786 filter: Optional regex to filter model names. 787 788 Returns: 789 List[str]: Sorted list of "provider:model_id" strings. 790 """ 791 models = [] 792 793 if providers is None: 794 providers_to_list = self.providers 795 elif isinstance(providers, str): 796 providers_to_list = {providers: self.providers.get(providers)} 797 elif isinstance(providers, list): 798 providers_to_list = {p: self.providers.get(p) for p in providers} 799 else: 800 return [] 801 802 invalid_providers = [p for p in providers_to_list if p not in self.providers or self.providers[p] is None] 803 if invalid_providers: 804 self.logger.error(f"Invalid providers: {invalid_providers}") 805 return [] 806 807 regex_pattern = None 808 if filter: 809 try: 810 regex_pattern = re.compile(filter, re.IGNORECASE) 811 except re.error as e: 812 self.logger.error(f"Invalid regex pattern: {e}") 813 return [] 814 815 for provider_name, config in providers_to_list.items(): 816 sdk = config.get("sdk", "openai") 817 base_url = config.get("base_url") 818 api_key = config.get("api_key") 819 820 try: 821 if sdk == "openai": 822 client = openai.OpenAI(api_key=api_key, base_url=base_url) 823 response = client.models.list() 824 for model in response.data: 825 model_id = f"{provider_name}:{model.id}" 826 if not regex_pattern or regex_pattern.search(model_id): 827 models.append(model_id) 828 829 elif sdk == "anthropic": 830 client = Anthropic(api_key=api_key) 831 response = client.models.list() 832 for model in response: 833 model_id = f"{provider_name}:{model.id}" 834 if not regex_pattern or regex_pattern.search(model_id): 835 models.append(model_id) 836 else: 837 self.logger.warning(f"SDK '{sdk}' for provider '{provider_name}' is not supported by list_models()") 838 839 except Exception as e: 840 self.logger.warning(f"Failed to list models for {provider_name}: {e}") 841 842 self.models = models 843 844 return sorted(models, key=str.lower) 845 846 847 async def _retry_with_backoff(self, func: Callable, *args, **kwargs): 848 """Execute a function with exponential backoff retry logic. 849 850 This method implements a robust retry mechanism for API calls with 851 exponential backoff. It handles rate limits, connection errors, and 852 other transient failures. If all retries fail, it will attempt to 853 switch to a fallback model if configured. 854 855 Args: 856 func (Callable): The async function to execute 857 *args: Positional arguments to pass to the function 858 **kwargs: Keyword arguments to pass to the function 859 860 Returns: 861 The result of the function call if successful 862 863 Raises: 864 Exception: If all retries fail and no fallback model is available 865 """ 866 for attempt in range(self.max_retries + 1): 867 try: 868 return await func(*args, **kwargs) 869 870 except (RateLimitError, APIConnectionError, aiohttp.ClientError) as e: 871 if attempt == self.max_retries: 872 model_key = f"{self.provider}:{self.model}" 873 if self.fallback_model and model_key != self.fallback_model: 874 print(f"[yellow]Model '{model_key}' failed. Switching to fallback: {self.fallback_model}[/yellow]") 875 self._setup_client(self.fallback_model) 876 self._setup_encoding() 877 self._normalizer() 878 return await func(*args, **kwargs) # retry once with fallback model 879 else: 880 self.logger.error(f"All {self.max_retries} retries failed: {e}") 881 raise 882 883 delay = self.retry_delay * (2 ** attempt) 884 self.logger.warning(f"Retry {attempt + 1}/{self.max_retries} after {delay}s due to {e}") 885 self._log(f"[RETRY] Attempt {attempt + 1}/{self.max_retries} failed: {e}", level="warning") 886 await asyncio.sleep(delay) 887 888 except OpenAIError as e: 889 self.logger.error(f"OpenAI error: {e}") 890 raise 891 892 except Exception as e: 893 self.logger.error(f"Unexpected error: {e}") 894 raise 895 896 897 def interact( 898 self, 899 user_input: Optional[str], 900 quiet: bool = False, 901 tools: bool = True, 902 stream: bool = True, 903 markdown: bool = False, 904 model: Optional[str] = None, 905 output_callback: Optional[Callable[[str], None]] = None, 906 session_id: Optional[str] = None, 907 raw: Optional[bool] = None, 908 tool_suppress: bool = True, 909 timeout: float = 60.0 910 ) -> Union[Optional[str], "TokenStream"]: 911 """Main universal gateway for all LLM interaction. 912 913 This function serves as the single entry point for all interactions with the language model. 914 When `raw=False` (default), it handles the interaction internally and returns the full response. 915 When `raw=True`, it returns a context manager that yields chunks of the response for custom handling. 916 917 Args: 918 user_input: Text input from the user. 919 quiet: If True, don't print status info or progress. 920 tools: Enable (True) or disable (False) tool calling. 921 stream: Enable (True) or disable (False) streaming responses. 922 markdown: If True, renders content as markdown. 923 model: Optional model override. 924 output_callback: Optional callback to handle the output. 925 session_id: Optional session ID to load messages from. 926 raw: If True, return a context manager instead of handling the interaction internally. 927 If None, use the class-level setting from __init__. 928 tool_suppress: If True and raw=True, filter out tool-related status messages. 929 timeout: Maximum time in seconds to wait for the stream to complete when raw=True. 930 931 Returns: 932 If raw=False: The complete response from the model as a string, or None if there was an error. 933 If raw=True: A context manager that yields chunks of the response as they arrive. 934 935 Example with default mode: 936 response = ai.interact("Tell me a joke") 937 938 Example with raw mode: 939 with ai.interact("Tell me a joke", raw=True) as stream: 940 for chunk in stream: 941 print(chunk, end="", flush=True) 942 """ 943 if not user_input: 944 return None 945 946 if quiet or self.quiet: 947 markdown = False 948 stream = False 949 950 # Determine if we should use raw mode 951 # If raw parameter is explicitly provided, use that; otherwise use class setting 952 use_raw = self.raw if raw is None else raw 953 954 # If raw mode is requested, delegate to interact_raw 955 if use_raw: 956 return self._interact_raw( 957 user_input=user_input, 958 tools=tools, 959 model=model, 960 session_id=session_id, 961 tool_suppress=tool_suppress, 962 timeout=timeout 963 ) 964 965 # Setup model if specified 966 if model: 967 self._setup_client(model) 968 self._setup_encoding() 969 970 # Session handling 971 if self.session_enabled and session_id: 972 self.session_id = session_id 973 self.session_load(session_id) 974 975 # Add user message using messages_add 976 self.messages_add(role="user", content=user_input) 977 978 # Log token count estimate 979 token_count = self._count_tokens(self.history) 980 if not quiet: 981 print(f"[dim]Estimated tokens in context: {token_count} / {self.context_length}[/dim]") 982 983 # Make sure we have enough context space 984 if token_count > self.context_length: 985 if self._cycle_messages(): 986 if not quiet: 987 print("[red]Context window exceeded. Cannot proceed.[/red]") 988 return None 989 990 # Log user input 991 self._log(f"[USER] {user_input}") 992 993 # Handle the actual interaction with complete streaming for all responses 994 result = asyncio.run(self._interact_async_core( 995 user_input=user_input, 996 quiet=quiet, 997 tools=tools, 998 stream=stream, 999 markdown=markdown, 1000 output_callback=output_callback 1001 )) 1002 1003 # Log completion for this interaction 1004 self._log(f"[INTERACTION] Completed with {len(self.history)} total messages") 1005 1006 return result 1007 1008 1009 def _interact_raw( 1010 self, 1011 user_input: Optional[str], 1012 tools: bool = True, 1013 model: Optional[str] = None, 1014 session_id: Optional[str] = None, 1015 tool_suppress: bool = True, 1016 timeout: float = 60.0 1017 ): 1018 """ 1019 Low-level function that returns a raw stream of tokens from the model. 1020 1021 This method works as a context manager that yields a generator of streaming tokens. 1022 The caller is responsible for handling the output stream. Typically, this is used 1023 indirectly through interact() with raw=True. 1024 1025 Args: 1026 user_input: Text input from the user. 1027 tools: Enable (True) or disable (False) tool calling. 1028 model: Optional model override. 1029 session_id: Optional session ID to load messages from. 1030 tool_suppress: If True, filter out tool-related status messages. 1031 timeout: Maximum time in seconds to wait for the stream to complete. 1032 1033 Returns: 1034 A context manager that yields a stream of tokens. 1035 1036 Example: 1037 with ai.interact_raw("Hello world") as stream: 1038 for chunk in stream: 1039 print(chunk, end="", flush=True) 1040 """ 1041 if not user_input: 1042 return None 1043 1044 # Setup model if specified 1045 if model: 1046 self._setup_client(model) 1047 self._setup_encoding() 1048 1049 # Session handling 1050 if self.session_enabled and session_id: 1051 self.session_id = session_id 1052 self.session_load(session_id) 1053 1054 # Add user message 1055 self.messages_add(role="user", content=user_input) 1056 1057 # Log token count estimate 1058 token_count = self._count_tokens(self.history) 1059 self._log(f"[STREAM] Estimated tokens in context: {token_count} / {self.context_length}") 1060 1061 # Make sure we have enough context space 1062 if token_count > self.context_length: 1063 if self._cycle_messages(): 1064 self._log("[STREAM] Context window exceeded. Cannot proceed.", level="error") 1065 return None 1066 1067 # Log user input 1068 self._log(f"[USER] {user_input}") 1069 1070 # Create a token stream class using a thread-safe queue 1071 class TokenStream: 1072 def __init__(self, interactor, user_input, tools, tool_suppress, timeout): 1073 """Initialize a new TokenStream instance. 1074 1075 This class provides a context manager for streaming token responses 1076 from the AI model. It handles asynchronous token delivery, tool call 1077 suppression, and timeout management. 1078 1079 Args: 1080 interactor: The parent Interactor instance 1081 user_input: The user's input text 1082 tools: Whether tool calling is enabled 1083 tool_suppress: Whether to suppress tool-related status messages 1084 timeout: Maximum time in seconds to wait for stream completion 1085 """ 1086 self.interactor = interactor 1087 self.user_input = user_input 1088 self.tools = tools 1089 self.tool_suppress = tool_suppress 1090 self.timeout = timeout 1091 self.token_queue = queue.Queue() 1092 self.thread = None 1093 self.result = None 1094 self.error = None 1095 self.completed = False 1096 1097 def __enter__(self): 1098 """Enter the context manager and start the streaming process. 1099 1100 This method initializes the streaming worker thread and returns 1101 self for iteration. The worker thread handles the actual API 1102 communication and token delivery. 1103 1104 Returns: 1105 TokenStream: Self for iteration 1106 """ 1107 # Start the thread for async interaction 1108 def stream_worker(): 1109 """Worker thread that handles the streaming interaction. 1110 1111 This internal function runs in a separate thread to handle 1112 the asynchronous API communication and token delivery. 1113 """ 1114 # Define output callback to put tokens in queue 1115 def callback(text): 1116 """Process and queue incoming text tokens. 1117 1118 This internal function handles incoming text chunks, 1119 optionally filtering tool-related messages, and adds 1120 them to the token queue. 1121 1122 Args: 1123 text: The text chunk to process and queue 1124 """ 1125 # Filter out tool messages if requested 1126 if self.tool_suppress: 1127 try: 1128 # Check if this is a tool status message (JSON format) 1129 data = json.loads(text) 1130 if isinstance(data, dict) and data.get("type") == "tool_call": 1131 # Skip this message 1132 return 1133 except (json.JSONDecodeError, TypeError): 1134 # Not JSON or not a dict, continue normally 1135 pass 1136 1137 # Add to queue 1138 self.token_queue.put(text) 1139 1140 # Run the interaction in a new event loop 1141 loop = asyncio.new_event_loop() 1142 asyncio.set_event_loop(loop) 1143 1144 try: 1145 # Run the interaction 1146 self.result = loop.run_until_complete( 1147 self.interactor._interact_async_core( 1148 user_input=self.user_input, 1149 quiet=True, 1150 tools=self.tools, 1151 stream=True, 1152 markdown=False, 1153 output_callback=callback 1154 ) 1155 ) 1156 # Signal successful completion 1157 self.completed = True 1158 except Exception as e: 1159 self.error = str(e) 1160 self.interactor.logger.error(f"Streaming error: {traceback.format_exc()}") 1161 # Add error information to the queue if we haven't yielded anything yet 1162 if self.token_queue.empty(): 1163 self.token_queue.put(f"Error: {str(e)}") 1164 finally: 1165 # Signal end of stream regardless of success/failure 1166 self.token_queue.put(None) 1167 loop.close() 1168 1169 # Start the worker thread 1170 self.thread = threading.Thread(target=stream_worker) 1171 self.thread.daemon = True 1172 self.thread.start() 1173 1174 # Return self for iteration 1175 return self 1176 1177 def __iter__(self): 1178 """Return self as an iterator. 1179 1180 Returns: 1181 TokenStream: Self for iteration 1182 """ 1183 return self 1184 1185 def __next__(self): 1186 """Get the next token from the stream. 1187 1188 This method implements the iterator protocol, retrieving the next 1189 token from the queue with timeout handling. 1190 1191 Returns: 1192 str: The next token from the stream 1193 1194 Raises: 1195 StopIteration: When the stream is complete or times out 1196 """ 1197 # Get next token from queue with timeout to prevent hanging 1198 try: 1199 token = self.token_queue.get(timeout=self.timeout) 1200 if token is None: 1201 # End of stream 1202 raise StopIteration 1203 return token 1204 except queue.Empty: 1205 # Timeout reached 1206 self.interactor.logger.warning(f"Stream timeout after {self.timeout}s") 1207 if not self.completed and not self.error: 1208 # Clean up the thread - it might be stuck 1209 if self.thread and self.thread.is_alive(): 1210 # We can't forcibly terminate a thread in Python, 1211 # but we can report the issue 1212 self.interactor.logger.error("Stream worker thread is hung") 1213 raise StopIteration 1214 1215 def __exit__(self, exc_type, exc_val, exc_tb): 1216 """Exit the context manager and clean up resources. 1217 1218 This method handles cleanup when the context manager is exited, 1219 including thread cleanup and message history updates. 1220 1221 Args: 1222 exc_type: The exception type if an exception was raised 1223 exc_val: The exception value if an exception was raised 1224 exc_tb: The exception traceback if an exception was raised 1225 1226 Returns: 1227 bool: False to not suppress any exceptions 1228 """ 1229 # Clean up resources 1230 if self.thread and self.thread.is_alive(): 1231 self.thread.join(timeout=2.0) 1232 1233 # Add messages to history if successful 1234 if self.completed and self.result and not exc_type: 1235 if isinstance(self.result, str) and self.result != "No response.": 1236 # If we had a successful completion, ensure the result is in the history 1237 last_msg = self.interactor.history[-1] if self.interactor.history else None 1238 if not last_msg or last_msg.get("role") != "assistant" or last_msg.get("content") != self.result: 1239 # Add a clean assistant message to history if not already there 1240 self.interactor.messages_add(role="assistant", content=self.result) 1241 1242 # If there was an error in the stream processing, log it 1243 if self.error: 1244 self.interactor.logger.error(f"Stream processing error: {self.error}") 1245 1246 return False # Don't suppress exceptions 1247 1248 return TokenStream(self, user_input, tools, tool_suppress, timeout) 1249 1250 1251 async def _interact_async_core( 1252 self, 1253 user_input: str, 1254 quiet: bool = False, 1255 tools: bool = True, 1256 stream: bool = True, 1257 markdown: bool = False, 1258 output_callback: Optional[Callable] = None 1259 ) -> str: 1260 """Main SDK-agnostic async execution pipeline with tool call looping support.""" 1261 # Prepare display handler 1262 live = Live(console=console, refresh_per_second=100) if markdown and stream else None 1263 if live: 1264 live.start() 1265 1266 # Initialize variables for iteration tracking 1267 full_content = "" 1268 max_iterations = 5 # Prevent infinite loops 1269 iterations = 0 1270 1271 # Main interaction loop - continues until no more tool calls or max iterations reached 1272 while iterations < max_iterations: 1273 iterations += 1 1274 1275 try: 1276 # Execute the appropriate SDK runner - history is already normalized 1277 response_data = await self.sdk_runner( 1278 model=self.model, 1279 messages=self.history, 1280 stream=stream, 1281 markdown=markdown, 1282 quiet=quiet if iterations == 1 else False, 1283 live=live, 1284 output_callback=output_callback 1285 ) 1286 1287 # Extract response data 1288 content = response_data.get("content", "") 1289 tool_calls = response_data.get("tool_calls", []) 1290 1291 # Log the response data for debugging 1292 self._log(f"[ITERATION {iterations}] Content: {len(content)} chars, Tool calls: {len(tool_calls)}") 1293 1294 # Add content to full response 1295 if iterations == 1: 1296 full_content = content 1297 elif content: 1298 if full_content and content: 1299 full_content += f"\n{content}" 1300 else: 1301 full_content = content 1302 1303 # Add assistant message with or without tool calls 1304 if tool_calls: 1305 # Process each tool call 1306 for call in tool_calls: 1307 # Add assistant message with tool call 1308 tool_info = { 1309 "id": call["id"], 1310 "name": call["function"]["name"], 1311 "arguments": call["function"]["arguments"] 1312 } 1313 1314 # Add the assistant message with tool call 1315 self.messages_add( 1316 role="assistant", 1317 content=content if len(tool_calls) == 1 else "", 1318 tool_info=tool_info 1319 ) 1320 1321 # Execute the tool 1322 call_name = call["function"]["name"] 1323 call_args = call["function"]["arguments"] 1324 call_id = call["id"] 1325 1326 # Stop Rich Live while executing tool calls 1327 live_was_active = True 1328 if live and live.is_started: 1329 live_was_active = True 1330 live.stop() 1331 1332 result = await self._handle_tool_call_async( 1333 function_name=call_name, 1334 function_arguments=call_args, 1335 tool_call_id=call_id, 1336 quiet=quiet, 1337 safe=False, 1338 output_callback=output_callback 1339 ) 1340 1341 # Restart live display if it was active before 1342 if live_was_active and live: 1343 live.start() 1344 1345 # Add tool result message 1346 tool_result_info = { 1347 "id": call_id, 1348 "result": result 1349 } 1350 1351 self.messages_add( 1352 role="tool", 1353 content=result, 1354 tool_info=tool_result_info 1355 ) 1356 else: 1357 # Simple assistant response without tool calls 1358 self.messages_add(role="assistant", content=content) 1359 break # No more tools to process, we're done 1360 1361 # Reset live display if needed 1362 if stream and live: 1363 live.stop() 1364 live = Live(console=console, refresh_per_second=100) 1365 live.start() 1366 1367 except Exception as e: 1368 self.logger.error(f"[{self.sdk.upper()} ERROR] {str(e)}") 1369 self._log(f"[ERROR] Error in interaction loop: {str(e)}", level="error") 1370 if live: 1371 live.stop() 1372 return f"Error: {str(e)}" 1373 1374 # Clean up display 1375 if live: 1376 live.stop() 1377 1378 return full_content or None 1379 1380 1381 async def _openai_runner( 1382 self, 1383 *, 1384 model, 1385 messages, 1386 stream, 1387 markdown=False, 1388 quiet=False, 1389 live=None, 1390 output_callback=None 1391 ): 1392 """Handle OpenAI-specific API interactions and response processing.""" 1393 # Log what we're sending for debugging 1394 self._log(f"[OPENAI REQUEST] Sending request to {model} with {len(self.history)} messages", level="debug") 1395 1396 # Prepare API parameters - history is already normalized by _normalizer 1397 params = { 1398 "model": model, 1399 "messages": self.history, 1400 "stream": stream, 1401 } 1402 1403 # Add tools if enabled 1404 if self.tools_enabled and self.tools_supported: 1405 enabled_tools = self._get_enabled_tools() 1406 if enabled_tools: 1407 params["tools"] = enabled_tools 1408 params["tool_choice"] = "auto" 1409 1410 # Call API with retry handling 1411 try: 1412 response = await self._retry_with_backoff( 1413 self.async_client.chat.completions.create, 1414 **params 1415 ) 1416 except Exception: 1417 self.logger.error(f"[OPENAI ERROR RUNNER]: {traceback.format_exc()}") 1418 raise 1419 1420 assistant_content = "" 1421 tool_calls_dict = {} 1422 1423 # Process streaming response 1424 if stream and hasattr(response, "__aiter__"): 1425 async for chunk in response: 1426 delta = getattr(chunk.choices[0], "delta", None) 1427 1428 # Handle content chunks 1429 if hasattr(delta, "content") and delta.content is not None: 1430 text = delta.content 1431 assistant_content += text 1432 if output_callback: 1433 output_callback(text) 1434 elif live: 1435 live.update(Markdown(assistant_content)) 1436 elif not markdown: 1437 print(text, end="") 1438 1439 # Process tool calls 1440 if hasattr(delta, "tool_calls") and delta.tool_calls: 1441 for tool_call_delta in delta.tool_calls: 1442 index = tool_call_delta.index 1443 if index not in tool_calls_dict: 1444 tool_calls_dict[index] = { 1445 "id": tool_call_delta.id if hasattr(tool_call_delta, "id") else None, 1446 "function": {"name": "", "arguments": ""} 1447 } 1448 1449 function = getattr(tool_call_delta, "function", None) 1450 if function: 1451 name = getattr(function, "name", None) 1452 args = getattr(function, "arguments", "") 1453 if name: 1454 tool_calls_dict[index]["function"]["name"] = name 1455 if args: 1456 tool_calls_dict[index]["function"]["arguments"] += args 1457 if tool_call_delta.id and not tool_calls_dict[index]["id"]: 1458 tool_calls_dict[index]["id"] = tool_call_delta.id 1459 1460 # Make sure the ID is set regardless 1461 if hasattr(tool_call_delta, "id") and tool_call_delta.id and not tool_calls_dict[index]["id"]: 1462 tool_calls_dict[index]["id"] = tool_call_delta.id 1463 1464 if not output_callback and not markdown and not quiet: 1465 print() 1466 1467 # Process non-streaming response 1468 else: 1469 message = response.choices[0].message 1470 assistant_content = message.content or "" 1471 1472 if hasattr(message, "tool_calls") and message.tool_calls: 1473 for i, tool_call in enumerate(message.tool_calls): 1474 tool_calls_dict[i] = { 1475 "id": tool_call.id, 1476 "function": { 1477 "name": tool_call.function.name, 1478 "arguments": tool_call.function.arguments 1479 } 1480 } 1481 1482 if output_callback: 1483 output_callback(assistant_content) 1484 elif not quiet: 1485 print(assistant_content) 1486 1487 # Log tool calls for debugging 1488 if tool_calls_dict: 1489 self._log(f"[OPENAI TOOL CALLS] Found {len(tool_calls_dict)} tool calls", level="debug") 1490 for idx, call in tool_calls_dict.items(): 1491 self._log(f"[OPENAI TOOL CALL {idx}] {call['function']['name']} with ID {call['id']}", level="debug") 1492 1493 # Return standardized response format 1494 return { 1495 "content": assistant_content, 1496 "tool_calls": list(tool_calls_dict.values()) 1497 } 1498 1499 1500 async def _anthropic_runner( 1501 self, 1502 *, 1503 model, 1504 messages, 1505 stream, 1506 markdown=False, 1507 quiet=False, 1508 live=None, 1509 output_callback=None 1510 ): 1511 """Handle Anthropic-specific API interactions and response processing.""" 1512 # Log what we're sending for debugging 1513 self._log(f"[ANTHROPIC REQUEST] Sending request to {model} with {len(self.history)} messages", level="debug") 1514 1515 # Prepare API parameters - history is already normalized by _normalizer 1516 params = { 1517 "model": model, 1518 "messages": self.history, 1519 "max_tokens": 8192, 1520 "system": self.system 1521 } 1522 1523 # Add tools support if needed 1524 if self.tools_enabled and self.tools_supported: 1525 enabled_tools = [] 1526 for tool in self._get_enabled_tools(): 1527 # Extract parameters from OpenAI format 1528 tool_params = tool["function"]["parameters"] 1529 1530 # Create Anthropic-compatible tool definition 1531 format_tool = { 1532 "name": tool["function"]["name"], 1533 "description": tool["function"].get("description", ""), 1534 "input_schema": { 1535 "type": "object", 1536 "properties": tool_params.get("properties", {}) 1537 } 1538 } 1539 1540 # Ensure 'required' is at the correct level for Anthropic (as a direct child of input_schema) 1541 if "required" in tool_params: 1542 format_tool["input_schema"]["required"] = tool_params["required"] 1543 1544 enabled_tools.append(format_tool) 1545 1546 params["tools"] = enabled_tools 1547 1548 assistant_content = "" 1549 tool_calls_dict = {} 1550 1551 try: 1552 # Process streaming response 1553 if stream: 1554 stream_params = params.copy() 1555 stream_params["stream"] = True 1556 1557 stream_response = await self._retry_with_backoff( 1558 self.async_client.messages.create, 1559 **stream_params 1560 ) 1561 1562 content_type = None 1563 async for chunk in stream_response: 1564 chunk_type = getattr(chunk, "type", "unknown") 1565 self._log(f"[ANTHROPIC CHUNK] Type: {chunk_type}", level="debug") 1566 if chunk_type == "content_block_start" and hasattr(chunk.content_block, "type"): 1567 content_type = chunk.content_block.type 1568 if content_type == "tool_use": 1569 tool_id = chunk.content_block.id 1570 tool_name = chunk.content_block.name 1571 tool_input = chunk.content_block.input 1572 tool_calls_dict[tool_id] = { 1573 "id": tool_id, 1574 "function": { 1575 "name": tool_name, 1576 "arguments": "" 1577 } 1578 } 1579 self._log(f"[ANTHROPIC TOOL USE] {tool_name}", level="debug") 1580 1581 # Handle text content 1582 if chunk_type == "content_block_delta" and hasattr(chunk.delta, "text"): 1583 delta = chunk.delta.text 1584 assistant_content += delta 1585 if output_callback: 1586 output_callback(delta) 1587 elif live: 1588 live.update(Markdown(assistant_content)) 1589 elif not markdown: 1590 print(delta, end="") 1591 1592 # Handle complete tool use 1593 elif chunk_type == "content_block_delta" and content_type == "tool_use": 1594 tool_calls_dict[tool_id]["function"]["arguments"] += chunk.delta.partial_json 1595 1596 # Process non-streaming response 1597 else: 1598 # For non-streaming, ensure we don't send the stream parameter 1599 non_stream_params = params.copy() 1600 non_stream_params.pop("stream", None) # Remove stream if it exists 1601 1602 response = await self._retry_with_backoff( 1603 self.async_client.messages.create, 1604 **non_stream_params 1605 ) 1606 1607 # Extract text content 1608 for content_block in response.content: 1609 if content_block.type == "text": 1610 assistant_content += content_block.text 1611 1612 if content_block.type == "tool_use": 1613 tool_id = content_block.id 1614 tool_name = content_block.name 1615 tool_input = content_block.input 1616 tool_calls_dict[tool_id] = { 1617 "id": tool_id, 1618 "function": { 1619 "name": tool_name, 1620 "arguments": tool_input 1621 } 1622 } 1623 self._log(f"[ANTHROPIC TOOL USE] {tool_name}", level="debug") 1624 1625 if output_callback: 1626 output_callback(assistant_content) 1627 elif not quiet: 1628 print(assistant_content) 1629 1630 except Exception as e: 1631 self._log(f"[ANTHROPIC ERROR RUNNER] {traceback.format_exc()}", level="error") 1632 1633 # Return something usable even in case of error 1634 return { 1635 "content": f"Error processing Anthropic response: {str(e)}", 1636 "tool_calls": [] 1637 } 1638 1639 # Return standardized response format 1640 return { 1641 "content": assistant_content, 1642 "tool_calls": list(tool_calls_dict.values()) 1643 } 1644 1645 1646 def _get_enabled_tools(self) -> List[dict]: 1647 """Return the list of currently enabled tool function definitions.""" 1648 return [ 1649 tool for tool in self.tools 1650 if not tool["function"].get("disabled", False) 1651 ] 1652 1653 1654 async def _handle_tool_call_async( 1655 self, 1656 function_name: str, 1657 function_arguments: str, 1658 tool_call_id: str, 1659 quiet: bool = False, 1660 safe: bool = False, 1661 output_callback: Optional[Callable[[str], None]] = None 1662 ) -> str: 1663 """Process a tool call asynchronously and return the result. 1664 1665 Args: 1666 function_name: Name of the function to call. 1667 function_arguments: JSON string containing the function arguments. 1668 tool_call_id: Unique identifier for this tool call. 1669 params: Parameters used for the original API call. 1670 safe: If True, prompts for confirmation before executing the tool call. 1671 output_callback: Optional callback to handle the tool call result. 1672 1673 Returns: 1674 The result of the function call. 1675 1676 Raises: 1677 ValueError: If the function is not found or JSON is invalid. 1678 """ 1679 if isinstance(function_arguments, str): 1680 arguments = json.loads(function_arguments) 1681 else: 1682 arguments = function_arguments 1683 1684 self._log(f"[TOOL:{function_name}] args={arguments}") 1685 1686 func = getattr(self, function_name, None) 1687 if not func: 1688 raise ValueError(f"Function '{function_name}' not found.") 1689 1690 be_quiet = self.quiet if quiet is None else quiet 1691 1692 if not be_quiet: 1693 print(f"\nRunning {function_name}...") 1694 1695 if output_callback: 1696 notification = json.dumps({ 1697 "type": "tool_call", 1698 "tool_name": function_name, 1699 "status": "started" 1700 }) 1701 output_callback(notification) 1702 1703 try: 1704 if safe: 1705 prompt = f"[bold yellow]Proposed tool call:[/bold yellow] {function_name}({json.dumps(arguments, indent=2)})\n[bold cyan]Execute? [y/n]: [/bold cyan]" 1706 confirmed = Confirm.ask(prompt, default=False) 1707 if not confirmed: 1708 command_result = { 1709 "status": "cancelled", 1710 "message": "Tool call aborted by user" 1711 } 1712 print("[red]Tool call cancelled by user[/red]") 1713 else: 1714 loop = asyncio.get_event_loop() 1715 command_result = await loop.run_in_executor(None, lambda: func(**arguments)) 1716 else: 1717 loop = asyncio.get_event_loop() 1718 command_result = await loop.run_in_executor(None, lambda: func(**arguments)) 1719 1720 try: 1721 json.dumps(command_result) 1722 except TypeError as e: 1723 self.logger.error(f"Tool call result not serializable: {e}") 1724 return {"error": "Tool call returned unserializable data."} 1725 1726 if output_callback: 1727 notification = json.dumps({ 1728 "type": "tool_call", 1729 "tool_name": function_name, 1730 "status": "completed" 1731 }) 1732 output_callback(notification) 1733 1734 return command_result 1735 1736 except Exception as e: 1737 self._log(f"[ERROR] Tool execution failed: {e}", level="error") 1738 self.logger.error(f"Error executing tool function '{function_name}': {e}") 1739 return {"error": str(e)} 1740 1741 1742 def _setup_encoding(self): 1743 """Initialize the token encoding system for the current model. 1744 1745 This method sets up the appropriate tokenizer based on the current 1746 model provider. For OpenAI models, it attempts to use the model-specific 1747 tokenizer, falling back to cl100k_base if not available. For other 1748 providers, it uses cl100k_base as a default. 1749 1750 The encoding is used for token counting and context management. 1751 """ 1752 try: 1753 if self.provider == "openai": 1754 try: 1755 self.encoding = tiktoken.encoding_for_model(self.model) 1756 self._log(f"[ENCODING] Loaded tokenizer for OpenAI model: {self.model}") 1757 except: 1758 self.encoding = tiktoken.get_encoding("cl100k_base") 1759 self._log(f"[ENCODING] Fallback to cl100k_base for model: {self.model}") 1760 else: 1761 self.encoding = tiktoken.get_encoding("cl100k_base") 1762 self._log(f"[ENCODING] Defaulting to cl100k_base for non-OpenAI model: {self.model}") 1763 except Exception as e: 1764 self.logger.error(f"Failed to setup encoding: {e}") 1765 self.encoding = tiktoken.get_encoding("cl100k_base") 1766 1767 1768 def _estimate_tokens_tiktoken(self, messages) -> int: 1769 """Rough token count estimate using tiktoken for OpenAI or fallback cases.""" 1770 if not hasattr(self, "encoding") or not self.encoding: 1771 self._setup_encoding() 1772 return sum(len(self.encoding.encode(msg.get("content", ""))) for msg in messages if isinstance(msg.get("content"), str)) 1773 1774 1775 def _count_tokens(self, messages, use_cache=True) -> int: 1776 """Accurately estimate token count for messages including tool calls with caching support. 1777 1778 Args: 1779 messages: List of message objects in either OpenAI or Anthropic format. 1780 use_cache: Whether to use and update the token count cache. 1781 1782 Returns: 1783 int: Estimated token count. 1784 """ 1785 # Setup encoding if needed 1786 if not hasattr(self, "encoding") or not self.encoding: 1787 self._setup_encoding() 1788 1789 # Initialize cache if it doesn't exist 1790 if not hasattr(self, "_token_count_cache"): 1791 self._token_count_cache = {} 1792 1793 # Generate a cache key based on message content hashes 1794 if use_cache: 1795 try: 1796 # Create a cache key using message IDs if available, or content hashes 1797 cache_key_parts = [] 1798 for msg in messages: 1799 if isinstance(msg, dict): 1800 # Try to use stable identifiers for cache key 1801 msg_id = msg.get("id", None) 1802 timestamp = msg.get("timestamp", None) 1803 1804 if msg_id and timestamp: 1805 cache_key_parts.append(f"{msg_id}:{timestamp}") 1806 else: 1807 # Fall back to content-based hash if no stable IDs 1808 content_str = str(msg.get("content", "")) 1809 role = msg.get("role", "unknown") 1810 cache_key_parts.append(f"{role}:{hash(content_str)}") 1811 1812 cache_key = ":".join(cache_key_parts) 1813 if cache_key in self._token_count_cache: 1814 return self._token_count_cache[cache_key] 1815 except Exception as e: 1816 # If caching fails, just continue with normal counting 1817 self._log(f"[TOKEN COUNT] Cache key generation failed: {e}", level="debug") 1818 use_cache = False 1819 1820 # For Claude models, try to use their built-in token counter 1821 if self.sdk == "anthropic": 1822 try: 1823 # Convert messages to Anthropic format if needed 1824 anthropic_messages = [] 1825 for msg in messages: 1826 if msg.get("role") == "system": 1827 continue # System handled separately 1828 1829 if msg.get("role") == "tool": 1830 # Skip tool messages in token count to avoid double-counting 1831 continue 1832 1833 if msg.get("role") == "user" and isinstance(msg.get("content"), list): 1834 # Already in Anthropic format with tool_result 1835 anthropic_messages.append(msg) 1836 elif msg.get("role") in ["user", "assistant"]: 1837 if not msg.get("tool_calls") and not msg.get("tool_use"): 1838 # Simple message 1839 anthropic_messages.append({ 1840 "role": msg.get("role"), 1841 "content": msg.get("content", "") 1842 }) 1843 1844 # Use Anthropic's token counter if messages exist 1845 if anthropic_messages: 1846 response = self.client.messages.count_tokens( 1847 model=self.model, 1848 messages=anthropic_messages, 1849 system=self.system 1850 ) 1851 token_count = response.input_tokens 1852 1853 # Cache the result for future use 1854 if use_cache and 'cache_key' in locals(): 1855 self._token_count_cache[cache_key] = token_count 1856 1857 return token_count 1858 except Exception as e: 1859 # Fall back to our estimation 1860 self._log(f"[TOKEN COUNT] Error using Anthropic token counter: {e}", level="debug") 1861 1862 # More accurate token counting for all message types 1863 num_tokens = 0 1864 1865 # Count tokens for each message 1866 for msg in messages: 1867 # Base token count for message metadata (role + message format) 1868 num_tokens += 4 # Message overhead 1869 1870 # Add tokens for role name 1871 role = msg.get("role", "") 1872 num_tokens += len(self.encoding.encode(role)) 1873 1874 # Count tokens in message content 1875 if isinstance(msg.get("content"), str): 1876 content = msg.get("content", "") 1877 content_tokens = len(self.encoding.encode(content)) 1878 num_tokens += content_tokens 1879 1880 elif isinstance(msg.get("content"), list): 1881 # Handle Anthropic-style content lists 1882 for item in msg.get("content", []): 1883 if isinstance(item, dict): 1884 # Tool result or other structured content 1885 if item.get("type") == "tool_result": 1886 result_content = item.get("content", "") 1887 if isinstance(result_content, str): 1888 num_tokens += len(self.encoding.encode(result_content)) 1889 else: 1890 # JSON serialization for dict/list content 1891 num_tokens += len(self.encoding.encode(json.dumps(result_content))) 1892 # Add tokens for tool_use_id and type fields 1893 num_tokens += len(self.encoding.encode(item.get("type", ""))) 1894 num_tokens += len(self.encoding.encode(item.get("tool_use_id", ""))) 1895 1896 # Text content type 1897 elif item.get("type") == "text": 1898 num_tokens += len(self.encoding.encode(item.get("text", ""))) 1899 1900 # Tool use type 1901 elif item.get("type") == "tool_use": 1902 num_tokens += len(self.encoding.encode(item.get("name", ""))) 1903 tool_input = item.get("input", {}) 1904 if isinstance(tool_input, str): 1905 num_tokens += len(self.encoding.encode(tool_input)) 1906 else: 1907 num_tokens += len(self.encoding.encode(json.dumps(tool_input))) 1908 num_tokens += len(self.encoding.encode(item.get("id", ""))) 1909 else: 1910 # Plain text content 1911 num_tokens += len(self.encoding.encode(str(item))) 1912 1913 # Count tokens in tool calls for OpenAI format 1914 if msg.get("tool_calls"): 1915 for tool_call in msg.get("tool_calls", []): 1916 if isinstance(tool_call, dict): 1917 # Count tokens for function name 1918 func_name = tool_call.get("function", {}).get("name", "") 1919 num_tokens += len(self.encoding.encode(func_name)) 1920 1921 # Count tokens for arguments 1922 args = tool_call.get("function", {}).get("arguments", "") 1923 if isinstance(args, str): 1924 num_tokens += len(self.encoding.encode(args)) 1925 else: 1926 num_tokens += len(self.encoding.encode(json.dumps(args))) 1927 1928 # Add tokens for id and type fields 1929 num_tokens += len(self.encoding.encode(tool_call.get("id", ""))) 1930 num_tokens += len(self.encoding.encode(tool_call.get("type", "function"))) 1931 1932 # Count tokens in Anthropic tool_use field 1933 if msg.get("tool_use"): 1934 tool_use = msg.get("tool_use") 1935 # Count tokens for name 1936 num_tokens += len(self.encoding.encode(tool_use.get("name", ""))) 1937 1938 # Count tokens for input 1939 tool_input = tool_use.get("input", {}) 1940 if isinstance(tool_input, str): 1941 num_tokens += len(self.encoding.encode(tool_input)) 1942 else: 1943 num_tokens += len(self.encoding.encode(json.dumps(tool_input))) 1944 1945 # Add tokens for id field 1946 num_tokens += len(self.encoding.encode(tool_use.get("id", ""))) 1947 1948 # Handle tool response message format 1949 if msg.get("role") == "tool": 1950 # Add tokens for tool_call_id 1951 tool_id = msg.get("tool_call_id", "") 1952 num_tokens += len(self.encoding.encode(tool_id)) 1953 1954 # Add message end tokens 1955 num_tokens += 2 1956 1957 # Cache the result for future use 1958 if use_cache and 'cache_key' in locals(): 1959 self._token_count_cache[cache_key] = num_tokens 1960 1961 return num_tokens 1962 1963 1964 def _cycle_messages(self): 1965 """Intelligently trim the message history to fit within the allowed context length. 1966 1967 This method implements a sophisticated trimming strategy that: 1968 1. Always preserves system messages 1969 2. Always keeps the most recent complete conversation turn 1970 3. Prioritizes keeping tool call chains intact 1971 4. Preserves important context from earlier exchanges 1972 5. Aggressively prunes redundant information before essential content 1973 1974 The method uses a multi-pass approach: 1975 1. First pass: Identify critical messages that must be kept 1976 2. Second pass: Calculate token counts for each message 1977 3. Third pass: Keep important tool chains intact 1978 4. Fourth pass: Fill remaining space with recent messages 1979 1980 Returns: 1981 bool: True if all messages were trimmed (context exceeded), 1982 False if trimming was successful and context is within limits 1983 1984 Note: 1985 This method maintains both session_history and history in sync, 1986 ensuring proper SDK-specific formatting is preserved. 1987 """ 1988 # Check if we need to trim 1989 token_count = self._count_tokens(self.history) 1990 1991 # If we're already under the limit, return early 1992 if token_count <= self.context_length: 1993 return False 1994 1995 self._log(f"[TRIM] Starting message cycling: {token_count} tokens exceeds {self.context_length} limit", level="info") 1996 1997 # We'll need to track tokens as we reconstruct the history 1998 target_tokens = max(self.context_length * 0.8, self.context_length - 1000) # Target 80% or 1000 less than max 1999 2000 # First pass: identify critical messages we must keep 2001 must_keep = [] 2002 tool_chain_groups = {} # Group related tool calls and their results 2003 2004 # Always keep system messages (should be first) 2005 system_indices = [] 2006 for i, msg in enumerate(self.history): 2007 if msg.get("role") == "system": 2008 system_indices.append(i) 2009 must_keep.append(i) 2010 2011 # Identify the most recent complete exchange (user question + assistant response) 2012 latest_exchange = [] 2013 # Start from the end and work backward to find the last complete exchange 2014 for i in range(len(self.history) - 1, -1, -1): 2015 msg = self.history[i] 2016 if msg.get("role") == "assistant" and not latest_exchange: 2017 latest_exchange.append(i) 2018 elif msg.get("role") == "user" and latest_exchange: 2019 latest_exchange.append(i) 2020 break 2021 2022 # Add the latest exchange to must-keep 2023 must_keep.extend(latest_exchange) 2024 2025 # Identify tool chains - track which messages belong to the same tool flow 2026 tool_id_to_chain = {} 2027 for i, msg in enumerate(self.history): 2028 # For assistant messages with tool calls 2029 if msg.get("role") == "assistant" and msg.get("tool_calls"): 2030 for tool_call in msg.get("tool_calls"): 2031 tool_id = tool_call.get("id") 2032 if tool_id: 2033 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2034 2035 # For tool response messages 2036 elif msg.get("role") == "tool" and msg.get("tool_call_id"): 2037 tool_id = msg.get("tool_call_id") 2038 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2039 2040 # For Anthropic format with tool use 2041 elif msg.get("role") == "assistant" and isinstance(msg.get("content"), list): 2042 for block in msg.get("content", []): 2043 if isinstance(block, dict) and block.get("type") == "tool_use": 2044 tool_id = block.get("id") 2045 if tool_id: 2046 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2047 2048 # For Anthropic tool result messages 2049 elif msg.get("role") == "user" and isinstance(msg.get("content"), list): 2050 for block in msg.get("content", []): 2051 if isinstance(block, dict) and block.get("type") == "tool_result": 2052 tool_id = block.get("tool_use_id") 2053 if tool_id: 2054 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2055 2056 # Group together all indices for each tool chain 2057 for tool_id, indices in tool_id_to_chain.items(): 2058 chain_key = f"tool_{min(indices)}" # Group by the earliest message 2059 if chain_key not in tool_chain_groups: 2060 tool_chain_groups[chain_key] = set() 2061 tool_chain_groups[chain_key].update(indices) 2062 2063 # Second pass: calculate tokens for each message 2064 message_tokens = [] 2065 for i, msg in enumerate(self.history): 2066 # Count tokens for this individual message 2067 tokens = self._count_tokens([msg]) 2068 message_tokens.append((i, tokens)) 2069 2070 # Keep the messages identified as must-keep 2071 keep_indices = set(must_keep) 2072 2073 # Calculate the tokens we've committed to keeping 2074 keep_tokens = sum(tokens for i, tokens in message_tokens if i in keep_indices) 2075 2076 # Check if we've already exceeded the target with just must-keep messages 2077 if keep_tokens > self.context_length: 2078 # We're in trouble - the essential messages alone exceed context 2079 # Drop older messages until we're under the limit 2080 all_indices = sorted(keep_indices) 2081 2082 # Start dropping oldest messages, but NEVER drop system messages 2083 for idx in all_indices: 2084 if idx not in system_indices: 2085 keep_indices.remove(idx) 2086 keep_tokens -= message_tokens[idx][1] 2087 if keep_tokens <= target_tokens: 2088 break 2089 2090 # If we've removed everything but system messages and still over limit 2091 if keep_tokens > self.context_length: 2092 self._log(f"[TRIM] Critical failure: even with minimal context ({keep_tokens} tokens), we exceed the limit", level="error") 2093 # Keep only system messages if any 2094 keep_indices = set(system_indices) 2095 return True # Context exceeded completely 2096 2097 # Third pass: keep the most important tool chains intact 2098 available_tokens = target_tokens - keep_tokens 2099 # Sort tool chains by recency (assumed by the chain_key which uses the earliest message) 2100 sorted_chains = sorted(tool_chain_groups.items(), key=lambda x: x[0], reverse=True) 2101 2102 for chain_key, indices in sorted_chains: 2103 # Skip if we've already decided to keep all messages in this chain 2104 if indices.issubset(keep_indices): 2105 continue 2106 2107 # Calculate how many tokens this chain would add 2108 chain_tokens = sum(tokens for i, tokens in message_tokens if i in indices and i not in keep_indices) 2109 2110 # If we can fit the entire chain, keep it 2111 if chain_tokens <= available_tokens: 2112 keep_indices.update(indices) 2113 available_tokens -= chain_tokens 2114 # Otherwise, we might want to keep partial chains in the future, but for now, skip 2115 2116 # Fourth pass: fill in with as many remaining messages as possible, prioritizing recency 2117 # Get remaining messages sorted by recency (newest first) 2118 remaining_indices = [(i, tokens) for i, tokens in message_tokens if i not in keep_indices] 2119 remaining_indices.sort(reverse=True) # Sort newest first 2120 2121 for i, tokens in remaining_indices: 2122 if tokens <= available_tokens: 2123 keep_indices.add(i) 2124 available_tokens -= tokens 2125 2126 # Final message reconstruction 2127 self._log(f"[TRIM] Keeping {len(keep_indices)}/{len(self.history)} messages, estimated {target_tokens - available_tokens} tokens", level="info") 2128 2129 # Create new history with just the kept messages, preserving order 2130 new_history = [self.history[i] for i in sorted(keep_indices)] 2131 self.history = new_history 2132 2133 # Update session_history to match the pruned history 2134 if hasattr(self, "session_history"): 2135 # Map between history items and session_history 2136 session_to_keep = [] 2137 2138 # For each session history message, check if it corresponds to a kept message 2139 for session_msg in self.session_history: 2140 # Keep system messages 2141 if session_msg.get("role") == "system": 2142 session_to_keep.append(session_msg) 2143 continue 2144 2145 # Try to match based on available IDs or content 2146 msg_id = session_msg.get("id") 2147 2148 # For tool messages, check tool_info.id against tool_call_id 2149 if "metadata" in session_msg and "tool_info" in session_msg["metadata"]: 2150 tool_id = session_msg["metadata"]["tool_info"].get("id") 2151 2152 # Check if this tool_id is still in the kept history 2153 for history_msg in new_history: 2154 # Check standard ids 2155 history_tool_id = None 2156 2157 # Check OpenAI format 2158 if history_msg.get("role") == "tool": 2159 history_tool_id = history_msg.get("tool_call_id") 2160 elif history_msg.get("role") == "assistant" and history_msg.get("tool_calls"): 2161 for call in history_msg.get("tool_calls", []): 2162 if call.get("id") == tool_id: 2163 history_tool_id = call.get("id") 2164 break 2165 2166 # Check Anthropic format 2167 elif isinstance(history_msg.get("content"), list): 2168 for block in history_msg.get("content", []): 2169 if isinstance(block, dict): 2170 if block.get("type") == "tool_use" and block.get("id") == tool_id: 2171 history_tool_id = block.get("id") 2172 break 2173 elif block.get("type") == "tool_result" and block.get("tool_use_id") == tool_id: 2174 history_tool_id = block.get("tool_use_id") 2175 break 2176 2177 if history_tool_id == tool_id: 2178 session_to_keep.append(session_msg) 2179 break 2180 2181 # For regular messages, try content matching as fallback 2182 else: 2183 content_match = False 2184 if isinstance(session_msg.get("content"), str) and session_msg.get("content"): 2185 for history_msg in new_history: 2186 if history_msg.get("role") == session_msg.get("role") and history_msg.get("content") == session_msg.get("content"): 2187 content_match = True 2188 break 2189 2190 if content_match: 2191 session_to_keep.append(session_msg) 2192 2193 # Update session_history with kept messages 2194 self.session_history = session_to_keep 2195 2196 # Re-normalize to ensure consistency 2197 self._normalizer(force=True) 2198 2199 # Verify our final token count 2200 final_token_count = self._count_tokens(self.history) 2201 self._log(f"[TRIM] Final history has {len(self.history)} messages, {final_token_count} tokens", level="info") 2202 2203 # Return whether we've completely exceeded context 2204 return final_token_count > self.context_length or len(self.history) == len(system_indices) 2205 2206 2207 def messages_add( 2208 self, 2209 role: str, 2210 content: Any, 2211 tool_info: Optional[Dict] = None, 2212 normalize: bool = True 2213 ) -> str: 2214 """ 2215 Add a message to the standardized session_history and then update SDK-specific history. 2216 2217 This method is the central point for all message additions to the conversation. 2218 2219 Args: 2220 role: The role of the message ("user", "assistant", "system", "tool") 2221 content: The message content (text or structured) 2222 tool_info: Optional tool-related metadata 2223 normalize: Whether to normalize history after adding this message 2224 2225 Returns: 2226 str: Unique ID of the added message 2227 """ 2228 # Generate a unique message ID 2229 message_id = str(uuid.uuid4()) 2230 2231 # Create the standardized message for session_history 2232 timestamp = datetime.now(timezone.utc).isoformat() 2233 2234 # Store system messages directly 2235 if role == "system": 2236 self.system = content 2237 2238 # Create standard format message 2239 standard_message = { 2240 "role": role, 2241 "content": content, 2242 "id": message_id, 2243 "timestamp": timestamp, 2244 "metadata": { 2245 "sdk": self.sdk 2246 } 2247 } 2248 2249 # Add tool info if provided 2250 if tool_info: 2251 standard_message["metadata"]["tool_info"] = tool_info 2252 2253 # Add to session_history 2254 if not hasattr(self, "session_history"): 2255 self.session_history = [] 2256 2257 self.session_history.append(standard_message) 2258 2259 # Save to persistent session if enabled 2260 if self.session_enabled and self.session_id: 2261 # Convert standard message to session-compatible format 2262 session_msg = { 2263 "role": role, 2264 "content": content, 2265 "id": message_id, 2266 "timestamp": timestamp 2267 } 2268 2269 # Add tool-related fields if present 2270 if tool_info: 2271 for key, value in tool_info.items(): 2272 session_msg[key] = value 2273 2274 # Store in session 2275 self.session.msg_insert(self.session_id, session_msg) 2276 2277 # Update the SDK-specific format in self.history by running the normalizer 2278 if normalize: 2279 # We only need to normalize the most recent message for efficiency 2280 # Pass a flag indicating we're just normalizing a new message 2281 self._normalizer(force=False, new_message_only=True) 2282 2283 # Log the added message 2284 self._log(f"[MESSAGE ADDED] {role}: {str(content)[:50]}...") 2285 2286 return message_id 2287 2288 2289 def messages_system(self, prompt: str): 2290 """Set or retrieve the current system prompt. 2291 2292 This method manages the system prompt that guides the AI's behavior. 2293 It can be used to both set a new system prompt and retrieve the current one. 2294 When setting a new prompt, it updates the system message in the conversation 2295 history and persists it to the session if enabled. 2296 2297 Args: 2298 prompt (str): The new system prompt to set. If empty or None, 2299 returns the current system prompt without changes. 2300 2301 Returns: 2302 str: The current system prompt after any updates 2303 """ 2304 if not isinstance(prompt, str) or not prompt: 2305 return self.system 2306 2307 # If the prompt hasn't changed, don't do anything 2308 if self.system == prompt: 2309 return self.system 2310 2311 # Update the system prompt 2312 old_system = self.system 2313 self.system = prompt 2314 2315 # For OpenAI, update or insert the system message in history 2316 if self.sdk == "openai": 2317 # Check if there's already a system message 2318 system_index = next((i for i, msg in enumerate(self.history) 2319 if msg.get("role") == "system"), None) 2320 2321 if system_index is not None: 2322 # Update existing system message 2323 self.history[system_index]["content"] = prompt 2324 else: 2325 # Insert new system message at the beginning 2326 self.history.insert(0, {"role": "system", "content": prompt}) 2327 2328 # For Anthropic, system message is not part of history, just save it for API calls 2329 2330 # Log to session only if prompt actually changed 2331 if self.session_enabled and self.session_id and old_system != prompt: 2332 self.session.msg_insert(self.session_id, {"role": "system", "content": prompt}) 2333 2334 return self.system 2335 2336 2337 def messages(self) -> list: 2338 """Return full session messages (persisted or in-memory).""" 2339 if self.session_enabled and self.session_id: 2340 return self.session.load_full(self.session_id).get("messages", []) 2341 return self.session_history 2342 2343 2344 def messages_length(self) -> int: 2345 """Calculate the total token count for the message history.""" 2346 if not self.encoding: 2347 return 0 2348 2349 total_tokens = 0 2350 for message in self.history: 2351 if message.get("content"): 2352 total_tokens += len(self.encoding.encode(message["content"])) 2353 if message.get("tool_calls"): 2354 for tool_call in message["tool_calls"]: 2355 if tool_call.get("function"): 2356 total_tokens += len(self.encoding.encode(tool_call["function"].get("name", ""))) 2357 total_tokens += len(self.encoding.encode(tool_call["function"].get("arguments", ""))) 2358 return total_tokens 2359 2360 2361 def session_load(self, session_id: Optional[str]): 2362 """Load and normalize messages for a specific session. 2363 2364 This method loads a conversation session from persistent storage and 2365 normalizes the messages to the current SDK format. It handles system 2366 messages, tool calls, and maintains message ordering. If loading fails, 2367 it resets to an empty session with the default system prompt. 2368 2369 Args: 2370 session_id (Optional[str]): The ID of the session to load. 2371 If None, resets to in-memory mode. 2372 2373 Note: 2374 This method will update both session_history and history to match 2375 the loaded session's state. It also ensures proper SDK-specific 2376 message formatting. 2377 """ 2378 self.session_id = session_id 2379 self._last_session_id = session_id 2380 2381 if self.session_enabled and session_id: 2382 try: 2383 # Load raw session data 2384 session_data = self.session.load_full(session_id) 2385 messages = session_data.get("messages", []) 2386 2387 # Convert session format to our standard format 2388 self.session_history = [] 2389 2390 # Track the most recent system message 2391 latest_system_msg = None 2392 2393 for msg in messages: 2394 # Extract fields 2395 role = msg.get("role", "user") 2396 content = msg.get("content", "") 2397 msg_id = msg.get("id", str(uuid.uuid4())) 2398 timestamp = msg.get("timestamp", datetime.now(timezone.utc).isoformat()) 2399 2400 # If this is a system message, track it but don't add to session_history yet 2401 if role == "system": 2402 if latest_system_msg is None or timestamp > latest_system_msg["timestamp"]: 2403 latest_system_msg = { 2404 "role": role, 2405 "content": content, 2406 "id": msg_id, 2407 "timestamp": timestamp, 2408 "metadata": {"sdk": self.sdk} 2409 } 2410 continue 2411 2412 # Build tool_info if present 2413 tool_info = None 2414 if any(key in msg for key in ["tool_use_id", "tool_call_id", "name", "arguments"]): 2415 tool_info = { 2416 "id": msg.get("tool_use_id") or msg.get("tool_call_id"), 2417 "name": msg.get("name", "unknown_tool"), 2418 "arguments": msg.get("arguments", {}) 2419 } 2420 2421 # Create standard message 2422 standard_msg = { 2423 "role": role, 2424 "content": content, 2425 "id": msg_id, 2426 "timestamp": timestamp, 2427 "metadata": { 2428 "sdk": self.sdk 2429 } 2430 } 2431 2432 if tool_info: 2433 standard_msg["metadata"]["tool_info"] = tool_info 2434 2435 self.session_history.append(standard_msg) 2436 2437 # If we found a system message, update the system property and add to history 2438 if latest_system_msg: 2439 self.system = latest_system_msg["content"] 2440 # Insert at the beginning of session_history 2441 self.session_history.insert(0, latest_system_msg) 2442 else: 2443 # If no system message was found, add the current system message 2444 self.messages_add(role="system", content=self.system) 2445 2446 # Normalize to current SDK format 2447 self._normalizer(force=True) 2448 2449 self._log(f"[SESSION] Switched to session '{session_id}'") 2450 except Exception as e: 2451 self.logger.error(f"Failed to load session '{session_id}': {e}") 2452 self.session_reset() 2453 else: 2454 # Reset to empty state with system message 2455 self.session_reset() 2456 2457 2458 def session_reset(self): 2459 """Reset the current session state and reinitialize to default system prompt. 2460 2461 This method performs a complete reset of the conversation state: 2462 1. Clears all message history 2463 2. Disables session ID tracking 2464 3. Returns to in-memory mode 2465 4. Reinitializes with the default system prompt 2466 2467 The reset is useful for starting fresh conversations or recovering 2468 from error states. It maintains the basic system configuration while 2469 clearing all conversation context. 2470 """ 2471 self.session_id = None 2472 self._last_session_id = None 2473 2474 # Clear histories 2475 self.session_history = [] 2476 self.history = [] 2477 2478 # Reapply the system message 2479 if hasattr(self, "system") and self.system: 2480 # Add to session_history 2481 self.messages_add(role="system", content=self.system) 2482 else: 2483 # Ensure we have a default system message 2484 self.system = "You are a helpful Assistant." 2485 self.messages_add(role="system", content=self.system) 2486 2487 self._log("[SESSION] Reset to in-memory mode") 2488 2489 2490 def _normalizer(self, force=False, new_message_only=False): 2491 """Central normalization function for message format conversion. 2492 2493 This method transforms the standardized session_history into the 2494 SDK-specific format needed in self.history. It handles different 2495 message types (system, user, assistant, tool) and their various 2496 formats across different SDKs. 2497 2498 Args: 2499 force (bool): If True, always normalize even if SDK hasn't changed. 2500 Default is False, which only normalizes on SDK change. 2501 new_message_only (bool): If True, only normalize the most recent message 2502 for efficiency when adding single messages. 2503 2504 Note: 2505 This method is the central point for message format conversion and 2506 ensures consistency between session storage and API communication. 2507 """ 2508 # Skip normalization if SDK hasn't changed and force is False 2509 if not force and hasattr(self, '_last_sdk') and self._last_sdk == self.sdk: 2510 # If we only need to normalize the most recent message 2511 if new_message_only and self.session_history: 2512 # Get the most recent message from session_history 2513 recent_msg = self.session_history[-1] 2514 2515 # Apply SDK-specific normalization for just this message 2516 if self.sdk == "openai": 2517 self._openai_normalize_message(recent_msg) 2518 elif self.sdk == "anthropic": 2519 self._anthropic_normalize_message(recent_msg) 2520 else: 2521 # Generic handler for unknown SDKs 2522 self._generic_normalize_message(recent_msg) 2523 2524 return 2525 2526 # Record the current SDK to detect future changes 2527 self._last_sdk = self.sdk 2528 2529 # For full normalization, clear current history and rebuild it 2530 self.history = [] 2531 2532 # Call the appropriate SDK-specific normalizer 2533 if self.sdk == "openai": 2534 self._openai_normalizer() 2535 elif self.sdk == "anthropic": 2536 self._anthropic_normalizer() 2537 else: 2538 self.logger.warning(f"No normalizer available for SDK: {self.sdk}") 2539 # Fallback to a simple conversion for unknown SDKs 2540 for msg in self.session_history: 2541 self._generic_normalize_message(msg) 2542 2543 2544 def _openai_normalizer(self): 2545 """Convert standardized session_history to OpenAI-compatible format. 2546 2547 This method transforms the internal message format into the structure 2548 required by the OpenAI API. It handles: 2549 - System messages at the start of history 2550 - User messages with plain text 2551 - Assistant messages with optional tool calls 2552 - Tool response messages with tool_call_id 2553 2554 The resulting format matches OpenAI's chat completion API requirements 2555 for both regular messages and function calling. 2556 """ 2557 # For OpenAI, we need to include system message in the history 2558 # and convert tool calls/results to OpenAI format 2559 2560 # Start with empty history 2561 self.history = [] 2562 2563 # First, add the current system message at position 0 2564 self.history.append({ 2565 "role": "system", 2566 "content": self.system 2567 }) 2568 2569 # Process all non-system messages 2570 for msg in self.session_history: 2571 if msg["role"] == "system": 2572 continue # Skip system messages, already handled 2573 2574 # Handle different message types 2575 if msg["role"] == "user": 2576 # User messages are straightforward 2577 self.history.append({ 2578 "role": "user", 2579 "content": msg["content"] 2580 }) 2581 2582 elif msg["role"] == "assistant": 2583 # For assistant messages with tool calls 2584 if "metadata" in msg and msg["metadata"].get("tool_info"): 2585 # This is an assistant message with tool calls 2586 tool_info = msg["metadata"]["tool_info"] 2587 2588 # Create OpenAI assistant message with tool calls 2589 assistant_msg = { 2590 "role": "assistant", 2591 "content": msg["content"] if isinstance(msg["content"], str) else "", 2592 "tool_calls": [{ 2593 "id": tool_info["id"], 2594 "type": "function", 2595 "function": { 2596 "name": tool_info["name"], 2597 "arguments": json.dumps(tool_info["arguments"]) if isinstance(tool_info["arguments"], dict) else tool_info["arguments"] 2598 } 2599 }] 2600 } 2601 self.history.append(assistant_msg) 2602 else: 2603 # Regular assistant message 2604 self.history.append({ 2605 "role": "assistant", 2606 "content": msg["content"] 2607 }) 2608 2609 elif msg["role"] == "tool": 2610 # Tool response messages 2611 if "metadata" in msg and "tool_info" in msg["metadata"]: 2612 tool_msg = { 2613 "role": "tool", 2614 "tool_call_id": msg["metadata"]["tool_info"]["id"], 2615 "content": json.dumps(msg["content"]) if isinstance(msg["content"], (dict, list)) else msg["content"] 2616 } 2617 self.history.append(tool_msg) 2618 2619 2620 def _anthropic_normalizer(self): 2621 """Convert standardized session_history to Anthropic-compatible format. 2622 2623 This method transforms the internal message format into the structure 2624 required by the Anthropic API. It handles: 2625 - System messages (stored separately, not in history) 2626 - User messages with optional tool results 2627 - Assistant messages with optional tool use 2628 - Content blocks for structured responses 2629 2630 The resulting format matches Anthropic's message API requirements 2631 for both regular messages and tool use. 2632 """ 2633 # For Anthropic, we don't include system message in the history 2634 # but need to handle content blocks for tool use/results 2635 2636 # Start with empty history 2637 self.history = [] 2638 2639 # Process all non-system messages 2640 for msg in self.session_history: 2641 if msg["role"] == "system": 2642 # Update system prompt if this is the most recent system message 2643 # (only apply the most recent system message if we have multiple) 2644 if msg == self.session_history[-1] or all(m["role"] != "system" for m in self.session_history[self.session_history.index(msg)+1:]): 2645 self.system = msg["content"] 2646 continue # Skip system messages in history 2647 2648 # Handle different message types 2649 if msg["role"] == "user": 2650 # User messages - check if it contains tool results 2651 if "metadata" in msg and "tool_info" in msg["metadata"] and msg["metadata"]["tool_info"].get("result"): 2652 # This is a tool result message 2653 tool_info = msg["metadata"]["tool_info"] 2654 2655 # Create Anthropic tool result format 2656 tool_result_msg = { 2657 "role": "user", 2658 "content": [{ 2659 "type": "tool_result", 2660 "tool_use_id": tool_info["id"], 2661 "content": json.dumps(tool_info["result"]) if isinstance(tool_info["result"], (dict, list)) else str(tool_info["result"]) 2662 }] 2663 } 2664 self.history.append(tool_result_msg) 2665 else: 2666 # Regular user message 2667 self.history.append({ 2668 "role": "user", 2669 "content": msg["content"] 2670 }) 2671 2672 elif msg["role"] == "assistant": 2673 # For assistant messages, check for tool use 2674 if "metadata" in msg and "tool_info" in msg["metadata"]: 2675 # This is an assistant message with tool use 2676 tool_info = msg["metadata"]["tool_info"] 2677 2678 # Build content blocks 2679 content_blocks = [] 2680 2681 # Add text content if present 2682 if msg["content"]: 2683 content_blocks.append({ 2684 "type": "text", 2685 "text": msg["content"] if isinstance(msg["content"], str) else "" 2686 }) 2687 2688 # Add tool use block 2689 content_blocks.append({ 2690 "type": "tool_use", 2691 "id": tool_info["id"], 2692 "name": tool_info["name"], 2693 "input": tool_info["arguments"] if isinstance(tool_info["arguments"], dict) else json.loads(tool_info["arguments"]) 2694 }) 2695 2696 # Create Anthropic assistant message with tool use 2697 self.history.append({ 2698 "role": "assistant", 2699 "content": content_blocks 2700 }) 2701 else: 2702 # Regular assistant message 2703 self.history.append({ 2704 "role": "assistant", 2705 "content": msg["content"] 2706 }) 2707 2708 elif msg["role"] == "tool": 2709 # Tool messages in standard format get converted to user messages with tool_result 2710 if "metadata" in msg and "tool_info" in msg["metadata"]: 2711 tool_info = msg["metadata"]["tool_info"] 2712 2713 # Create Anthropic tool result message 2714 tool_result_msg = { 2715 "role": "user", 2716 "content": [{ 2717 "type": "tool_result", 2718 "tool_use_id": tool_info["id"], 2719 "content": json.dumps(msg["content"]) if isinstance(msg["content"], (dict, list)) else str(msg["content"]) 2720 }] 2721 } 2722 self.history.append(tool_result_msg) 2723 2724 2725 def _openai_normalize_message(self, msg): 2726 """Normalize a single message to OpenAI format and add to history.""" 2727 role = msg.get("role") 2728 content = msg.get("content") 2729 2730 if role == "system": 2731 # Check if we already have a system message in history 2732 system_index = next((i for i, m in enumerate(self.history) 2733 if m.get("role") == "system"), None) 2734 if system_index is not None: 2735 # Update existing system message 2736 self.history[system_index]["content"] = content 2737 else: 2738 # Insert new system message at the beginning 2739 self.history.insert(0, { 2740 "role": "system", 2741 "content": content 2742 }) 2743 # Update the system property 2744 self.system = content 2745 2746 elif role == "user": 2747 self.history.append({ 2748 "role": "user", 2749 "content": content 2750 }) 2751 2752 elif role == "assistant": 2753 # For assistant messages, handle potential tool calls 2754 if "metadata" in msg and msg["metadata"].get("tool_info"): 2755 # This is an assistant message with tool calls 2756 tool_info = msg["metadata"]["tool_info"] 2757 2758 # Create OpenAI assistant message with tool calls 2759 try: 2760 arguments = tool_info.get("arguments", {}) 2761 arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else arguments 2762 except: 2763 arguments_str = str(arguments) 2764 2765 assistant_msg = { 2766 "role": "assistant", 2767 "content": content if isinstance(content, str) else "", 2768 "tool_calls": [{ 2769 "id": tool_info["id"], 2770 "type": "function", 2771 "function": { 2772 "name": tool_info["name"], 2773 "arguments": arguments_str 2774 } 2775 }] 2776 } 2777 self.history.append(assistant_msg) 2778 else: 2779 # Regular assistant message 2780 self.history.append({ 2781 "role": "assistant", 2782 "content": content 2783 }) 2784 2785 elif role == "tool": 2786 # Tool response messages 2787 if "metadata" in msg and "tool_info" in msg["metadata"]: 2788 tool_info = msg["metadata"]["tool_info"] 2789 tool_msg = { 2790 "role": "tool", 2791 "tool_call_id": tool_info["id"], 2792 "content": json.dumps(content) if isinstance(content, (dict, list)) else str(content) 2793 } 2794 self.history.append(tool_msg) 2795 2796 2797 def _anthropic_normalize_message(self, msg): 2798 """Normalize a single message to Anthropic format and add to history.""" 2799 role = msg.get("role") 2800 content = msg.get("content") 2801 2802 if role == "system": 2803 # Store system prompt separately, not in history for Anthropic 2804 self.system = content 2805 2806 elif role == "user": 2807 # User messages - check if it contains tool results 2808 if "metadata" in msg and "tool_info" in msg["metadata"]: 2809 tool_info = msg["metadata"]["tool_info"] 2810 # Check for result or directly use content 2811 result_content = tool_info.get("result", content) 2812 2813 # Create Anthropic tool result format 2814 try: 2815 result_str = json.dumps(result_content) if isinstance(result_content, (dict, list)) else str(result_content) 2816 except: 2817 result_str = str(result_content) 2818 2819 tool_result_msg = { 2820 "role": "user", 2821 "content": [{ 2822 "type": "tool_result", 2823 "tool_use_id": tool_info["id"], 2824 "content": result_str 2825 }] 2826 } 2827 self.history.append(tool_result_msg) 2828 else: 2829 # Regular user message 2830 self.history.append({ 2831 "role": "user", 2832 "content": content 2833 }) 2834 2835 elif role == "assistant": 2836 # For assistant messages, check for tool use 2837 if "metadata" in msg and "tool_info" in msg["metadata"]: 2838 # This is an assistant message with tool use 2839 tool_info = msg["metadata"]["tool_info"] 2840 2841 # Build content blocks 2842 content_blocks = [] 2843 2844 # Add text content if present 2845 if content: 2846 content_blocks.append({ 2847 "type": "text", 2848 "text": content if isinstance(content, str) else "" 2849 }) 2850 2851 # Add tool use block - safely convert arguments 2852 try: 2853 # Parse arguments to ensure it's a dictionary 2854 if isinstance(tool_info["arguments"], str): 2855 try: 2856 args_dict = json.loads(tool_info["arguments"]) 2857 except json.JSONDecodeError: 2858 args_dict = {"text": tool_info["arguments"]} 2859 else: 2860 args_dict = tool_info["arguments"] 2861 except: 2862 args_dict = {"error": "Failed to parse arguments"} 2863 2864 content_blocks.append({ 2865 "type": "tool_use", 2866 "id": tool_info["id"], 2867 "name": tool_info["name"], 2868 "input": args_dict 2869 }) 2870 2871 # Create Anthropic assistant message with tool use 2872 self.history.append({ 2873 "role": "assistant", 2874 "content": content_blocks 2875 }) 2876 else: 2877 # Regular assistant message 2878 self.history.append({ 2879 "role": "assistant", 2880 "content": content 2881 }) 2882 2883 elif role == "tool": 2884 # Tool messages in standard format get converted to user messages with tool_result 2885 if "metadata" in msg and "tool_info" in msg["metadata"]: 2886 tool_info = msg["metadata"]["tool_info"] 2887 2888 try: 2889 result_str = json.dumps(content) if isinstance(content, (dict, list)) else str(content) 2890 except: 2891 result_str = str(content) 2892 2893 # Create Anthropic tool result message 2894 tool_result_msg = { 2895 "role": "user", 2896 "content": [{ 2897 "type": "tool_result", 2898 "tool_use_id": tool_info["id"], 2899 "content": result_str 2900 }] 2901 } 2902 self.history.append(tool_result_msg) 2903 2904 2905 def _generic_normalize_message(self, msg): 2906 """Generic normalizer for unknown SDKs. 2907 2908 This method provides a basic message normalization for SDKs that 2909 don't have specific handling. It performs minimal conversion to 2910 ensure basic message structure is maintained. 2911 2912 Args: 2913 msg (dict): The message to normalize, containing at minimum: 2914 - role: Message role (user/assistant/system) 2915 - content: Message content 2916 2917 Note: 2918 This is a fallback method and should be overridden for specific 2919 SDK implementations when possible. 2920 """ 2921 role = msg.get("role") 2922 content = msg.get("content") 2923 2924 if role in ["user", "assistant", "system"]: 2925 self.history.append({ 2926 "role": role, 2927 "content": content 2928 }) 2929 2930 2931 def track_token_usage(self): 2932 """Track and return token usage across the conversation history. 2933 2934 This method maintains a history of token usage measurements and provides 2935 current usage statistics. It tracks: 2936 - Current token count 2937 - Context length limit 2938 - Usage percentage 2939 - Historical measurements 2940 - Current provider and model info 2941 2942 Returns: 2943 dict: Dictionary containing: 2944 - current: Current token count 2945 - limit: Maximum context length 2946 - percentage: Usage as percentage of limit 2947 - history: Last 10 measurements with timestamps 2948 - provider: Current provider name 2949 - model: Current model name 2950 2951 Note: 2952 The history is limited to the last 100 measurements to prevent 2953 unbounded memory growth. 2954 """ 2955 if not hasattr(self, "_token_history"): 2956 self._token_history = [] 2957 2958 # Count current tokens 2959 current_count = self._count_tokens(self.history) 2960 2961 # Add to history 2962 timestamp = datetime.now(timezone.utc).isoformat() 2963 self._token_history.append({ 2964 "timestamp": timestamp, 2965 "count": current_count, 2966 "limit": self.context_length, 2967 "provider": self.provider, 2968 "model": self.model 2969 }) 2970 2971 # Keep only the last 100 measurements to avoid unlimited growth 2972 if len(self._token_history) > 100: 2973 self._token_history = self._token_history[-100:] 2974 2975 # Return current tracking info 2976 return { 2977 "current": current_count, 2978 "limit": self.context_length, 2979 "percentage": round((current_count / self.context_length) * 100, 1) if self.context_length else 0, 2980 "history": self._token_history[-10:], # Return last 10 measurements 2981 "provider": self.provider, 2982 "model": self.model 2983 } 2984 2985 2986 def get_message_token_breakdown(self): 2987 """Analyze token usage by message type and provide a detailed breakdown. 2988 2989 This method performs a detailed analysis of token usage across the 2990 conversation history, breaking down usage by: 2991 - Message role (system, user, assistant, tool) 2992 - Content type (text, tool calls, tool results) 2993 - Individual message statistics 2994 2995 Returns: 2996 dict: Token usage breakdown containing: 2997 - total: Total tokens used 2998 - by_role: Tokens used by each role 2999 - by_type: Tokens used by content type 3000 - messages: List of individual message stats including: 3001 - index: Message position 3002 - role: Message role 3003 - tokens: Tokens used 3004 - has_tools: Whether message contains tool calls 3005 3006 Note: 3007 This analysis is useful for understanding token usage patterns 3008 and optimizing conversation context. 3009 """ 3010 breakdown = { 3011 "total": 0, 3012 "by_role": { 3013 "system": 0, 3014 "user": 0, 3015 "assistant": 0, 3016 "tool": 0 3017 }, 3018 "by_type": { 3019 "text": 0, 3020 "tool_calls": 0, 3021 "tool_results": 0 3022 }, 3023 "messages": [] 3024 } 3025 3026 # Analyze each message 3027 for i, msg in enumerate(self.history): 3028 msg_tokens = self._count_tokens([msg]) 3029 role = msg.get("role", "unknown") 3030 3031 # Track by role 3032 if role in breakdown["by_role"]: 3033 breakdown["by_role"][role] += msg_tokens 3034 3035 # Track by content type 3036 if role == "assistant" and msg.get("tool_calls"): 3037 breakdown["by_type"]["tool_calls"] += msg_tokens 3038 elif role == "tool": 3039 breakdown["by_type"]["tool_results"] += msg_tokens 3040 else: 3041 breakdown["by_type"]["text"] += msg_tokens 3042 3043 # Add individual message data 3044 breakdown["messages"].append({ 3045 "index": i, 3046 "role": role, 3047 "tokens": msg_tokens, 3048 "has_tools": bool(msg.get("tool_calls") or msg.get("tool_use") or 3049 (isinstance(msg.get("content"), list) and 3050 any(isinstance(c, dict) and c.get("type") in ["tool_use", "tool_result"] 3051 for c in msg.get("content", [])))) 3052 }) 3053 3054 # Update tota? 3055 breakdown["total"] += msg_tokens 3056 3057 return breakdown
65 def __init__( 66 self, 67 base_url: Optional[str] = None, 68 api_key: Optional[str] = None, 69 model: str = "openai:gpt-4o-mini", 70 fallback_model = "ollama:mistral-nemo:latest", 71 tools: Optional[bool] = True, 72 stream: bool = True, 73 quiet: bool = False, 74 context_length: int = 128000, 75 max_retries: int = 3, 76 retry_delay: float = 1.0, 77 log_path: Optional[str] = None, 78 log_level: Optional[str] = "ERROR", 79 raw: Optional[bool] = False, 80 session_enabled: bool = False, 81 session_id: Optional[str] = None, 82 session_path: Optional[str] = None 83 ): 84 """Initialize the universal AI interaction client. 85 86 Args: 87 base_url: Optional base URL for the API. If None, uses the provider's default URL. 88 api_key: Optional API key. If None, attempts to use environment variables based on provider. 89 model: Model identifier in format "provider:model_name" (e.g., "openai:gpt-4o-mini"). 90 tools: Enable (True) or disable (False) tool calling; None for auto-detection based on model support. 91 stream: Enable (True) or disable (False) streaming responses. 92 context_length: Maximum number of tokens to maintain in conversation history. 93 max_retries: Maximum number of retries for failed API calls. 94 retry_delay: Initial delay (in seconds) for exponential backoff retries. 95 session_enabled: Enable persistent session support. 96 session_id: Optional session ID to load messages from. 97 98 Raises: 99 ValueError: If provider is not supported or API key is missing for non-Ollama providers. 100 """ 101 self.system = "You are a helpful Assistant." 102 self.raw = raw 103 self.quiet = quiet 104 self.models = [] 105 106 # Configure Logging 107 level_map = { 108 "DEBUG": logging.DEBUG, 109 "INFO": logging.INFO, 110 "WARNING": logging.WARNING, 111 "ERROR": logging.ERROR, 112 "CRITICAL": logging.CRITICAL 113 } 114 115 self.logger = logging.getLogger(f"InteractorLogger_{id(self)}") 116 self.logger.setLevel(level_map[log_level]) 117 118 self.providers = { 119 "openai": { 120 "sdk": "openai", 121 "base_url": "https://api.openai.com/v1", 122 "api_key": api_key or os.getenv("OPENAI_API_KEY") or None 123 }, 124 "ollama": { 125 "sdk": "openai", 126 "base_url": "http://localhost:11434/v1", 127 "api_key": api_key or "ollama" 128 }, 129 "nvidia": { 130 "sdk": "openai", 131 "base_url": "https://integrate.api.nvidia.com/v1", 132 "api_key": api_key or os.getenv("NVIDIA_API_KEY") or None 133 }, 134 "google": { 135 "sdk": "openai", 136 "base_url": "https://generativelanguage.googleapis.com/v1beta/openai", 137 "api_key": api_key or os.getenv("GEMINI_API_KEY") or None 138 }, 139 "anthropic": { 140 "sdk": "anthropic", 141 "base_url": "https://api.anthropic.com/v1", 142 "api_key": api_key or os.getenv("ANTHROPIC_API_KEY") or None 143 }, 144 "mistral": { 145 "sdk": "openai", 146 "base_url": "https://api.mistral.ai/v1", 147 "api_key": api_key or os.getenv("MISTRAL_API_KEY") or None 148 }, 149 "deepseek": { 150 "sdk": "openai", 151 "base_url": "https://api.deepseek.com", 152 "api_key": api_key or os.getenv("DEEPSEEK_API_KEY") or None 153 }, 154 } 155 """ 156 "grok": { 157 "sdk": "grok", 158 "base_url": "https://api.x.ai/v1", 159 "api_key": api_key or os.getenv("GROK_API_KEY") or None 160 } 161 } 162 """ 163 164 # Console log handler (always enabled at WARNING+) 165 if not self.logger.handlers: 166 console_handler = logging.StreamHandler(sys.stdout) 167 console_handler.setLevel(logging.WARNING) 168 console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) 169 self.logger.addHandler(console_handler) 170 171 self._log_enabled = False 172 if log_path: 173 file_handler = logging.FileHandler(log_path) 174 file_handler.setLevel(logging.DEBUG) 175 file_handler.setFormatter(logging.Formatter( 176 "%(asctime)s - %(levelname)s - %(message)s", 177 datefmt="%Y-%m-%d %H:%M:%S" 178 )) 179 self.logger.addHandler(file_handler) 180 self._log_enabled = True 181 182 183 self.token_estimate = 0 184 self.last_token_estimate = 0 185 self.stream = stream 186 self.tools = [] 187 self.session_history = [] 188 self.history = [] 189 self.context_length = context_length 190 self.encoding = None 191 self.max_retries = max_retries 192 self.retry_delay = retry_delay 193 self.reveal_tool = [] 194 self.fallback_model = fallback_model 195 self.sdk = None 196 197 # Session support 198 self.session_enabled = session_enabled 199 self.session_id = session_id 200 self._last_session_id = session_id 201 self.session = Session(directory=session_path) if session_enabled else None 202 203 204 if model is None: 205 model = "openai:gpt-4o-mini" 206 207 # Initialize model + encoding 208 self._setup_client(model, base_url, api_key) 209 self.tools_enabled = self.tools_supported if tools is None else tools and self.tools_supported 210 self._setup_encoding() 211 self.messages_add(role="system", content=self.system)
Initialize the universal AI interaction client.
Args: base_url: Optional base URL for the API. If None, uses the provider's default URL. api_key: Optional API key. If None, attempts to use environment variables based on provider. model: Model identifier in format "provider:model_name" (e.g., "openai:gpt-4o-mini"). tools: Enable (True) or disable (False) tool calling; None for auto-detection based on model support. stream: Enable (True) or disable (False) streaming responses. context_length: Maximum number of tokens to maintain in conversation history. max_retries: Maximum number of retries for failed API calls. retry_delay: Initial delay (in seconds) for exponential backoff retries. session_enabled: Enable persistent session support. session_id: Optional session ID to load messages from.
Raises: ValueError: If provider is not supported or API key is missing for non-Ollama providers.
"grok": { "sdk": "grok", "base_url": "https://api.x.ai/v1", "api_key": api_key or os.getenv("GROK_API_KEY") or None } }
384 def add_function( 385 self, 386 external_callable: Callable, 387 name: Optional[str] = None, 388 description: Optional[str] = None, 389 override: bool = False, 390 disabled: bool = False, 391 schema_extensions: Optional[Dict[str, Any]] = None 392 ): 393 """ 394 Register a function for LLM tool calling with full type hints and metadata. 395 396 Args: 397 external_callable (Callable): The function to register. 398 name (Optional[str]): Optional custom name. Defaults to function's __name__. 399 description (Optional[str]): Optional custom description. Defaults to first line of docstring. 400 override (bool): If True, replaces an existing tool with the same name. 401 disabled (bool): If True, registers the function in a disabled state. 402 schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to 403 schema extensions that override or add to the auto-generated schema. 404 405 Raises: 406 ValueError: If the callable is invalid or duplicate name found without override. 407 408 Example: 409 interactor.add_function( 410 my_tool, 411 override=True, 412 disabled=False, 413 schema_extensions={ 414 "param1": {"minimum": 0, "maximum": 100}, 415 "param2": {"format": "email"} 416 } 417 ) 418 """ 419 def _python_type_to_schema(ptype: Any) -> dict: 420 """Convert a Python type annotation to OpenAI-compatible JSON Schema.""" 421 # Handle None case 422 if ptype is None: 423 return {"type": "null"} 424 425 # Get the origin and arguments of the type 426 origin = get_origin(ptype) 427 args = get_args(ptype) 428 429 # Handle Union types (including Optional) 430 if origin is Union: 431 # Check for Optional (Union with None) 432 none_type = type(None) 433 if none_type in args: 434 non_none = [a for a in args if a is not none_type] 435 if len(non_none) == 1: 436 inner = _python_type_to_schema(non_none[0]) 437 inner_copy = inner.copy() 438 inner_copy["nullable"] = True 439 return inner_copy 440 # Multiple types excluding None 441 types = [_python_type_to_schema(a) for a in non_none] 442 return {"anyOf": types, "nullable": True} 443 # Regular Union without None 444 return {"anyOf": [_python_type_to_schema(a) for a in args]} 445 446 # Handle List and similar container types 447 if origin in (list, List): 448 item_type = args[0] if args else Any 449 if item_type is Any: 450 return {"type": "array"} 451 return {"type": "array", "items": _python_type_to_schema(item_type)} 452 453 # Handle Dict types with typing info 454 if origin in (dict, Dict): 455 if not args or len(args) != 2: 456 return {"type": "object"} 457 458 key_type, val_type = args 459 # We can only really use val_type in JSON Schema 460 if val_type is not Any and val_type is not object: 461 return { 462 "type": "object", 463 "additionalProperties": _python_type_to_schema(val_type) 464 } 465 return {"type": "object"} 466 467 # Handle Literal types for enums 468 if origin is Literal: 469 values = args 470 # Try to determine type from values 471 if all(isinstance(v, str) for v in values): 472 return {"type": "string", "enum": list(values)} 473 elif all(isinstance(v, bool) for v in values): 474 return {"type": "boolean", "enum": list(values)} 475 elif all(isinstance(v, (int, float)) for v in values): 476 return {"type": "number", "enum": list(values)} 477 else: 478 # Mixed types, use anyOf 479 return {"anyOf": [{"type": _get_json_type(v), "enum": [v]} for v in values]} 480 481 # Handle basic types 482 if ptype is str: 483 return {"type": "string"} 484 if ptype is int: 485 return {"type": "integer"} 486 if ptype is float: 487 return {"type": "number"} 488 if ptype is bool: 489 return {"type": "boolean"} 490 491 # Handle common datetime types 492 if ptype is datetime: 493 return {"type": "string", "format": "date-time"} 494 if ptype is date: 495 return {"type": "string", "format": "date"} 496 497 # Handle UUID 498 if ptype is uuid.UUID: 499 return {"type": "string", "format": "uuid"} 500 501 # Default to object for any other types 502 return {"type": "object"} 503 504 def _get_json_type(value): 505 """Get the JSON Schema type name for a Python value. 506 507 This helper function maps Python types to their corresponding 508 JSON Schema type names. It handles basic types and provides 509 sensible defaults for complex types. 510 511 Args: 512 value: The Python value to get the JSON type for 513 514 Returns: 515 str: The JSON Schema type name ('string', 'number', 'boolean', 516 'array', 'object', or 'object' as default) 517 """ 518 if isinstance(value, str): 519 return "string" 520 elif isinstance(value, bool): 521 return "boolean" 522 elif isinstance(value, int) or isinstance(value, float): 523 return "number" 524 elif isinstance(value, list): 525 return "array" 526 elif isinstance(value, dict): 527 return "object" 528 else: 529 return "object" # Default 530 531 def _parse_param_docs(docstring: str) -> dict: 532 """Extract parameter descriptions from a docstring.""" 533 if not docstring: 534 return {} 535 536 lines = docstring.splitlines() 537 param_docs = {} 538 current_param = None 539 in_params = False 540 541 # Regular expressions for finding parameter sections and param lines 542 param_section_re = re.compile(r"^(Args|Parameters):\s*$") 543 param_line_re = re.compile(r"^\s{4}(\w+)\s*(?:\([^\)]*\))?:\s*(.*)") 544 545 for line in lines: 546 # Check if we're entering the parameters section 547 if param_section_re.match(line.strip()): 548 in_params = True 549 continue 550 551 if in_params: 552 # Skip empty lines 553 if not line.strip(): 554 continue 555 556 # Check for a parameter definition line 557 match = param_line_re.match(line) 558 if match: 559 current_param = match.group(1) 560 param_docs[current_param] = match.group(2).strip() 561 # Check for continuation of a parameter description 562 elif current_param and line.startswith(" " * 8): 563 param_docs[current_param] += " " + line.strip() 564 # If we see a line that doesn't match our patterns, we're out of the params section 565 else: 566 current_param = None 567 568 return param_docs 569 570 # Start of main function logic 571 572 # Skip if tools are disabled 573 if not self.tools_enabled: 574 return 575 576 # Validate input callable 577 if not external_callable: 578 raise ValueError("A valid external callable must be provided.") 579 580 # Set function name, either from parameter or from callable's __name__ 581 function_name = name or external_callable.__name__ 582 583 # Try to get docstring and extract description 584 try: 585 docstring = inspect.getdoc(external_callable) 586 description = description or (docstring.split("\n")[0].strip() if docstring else "No description provided.") 587 except Exception as e: 588 self._log(f"[TOOL] Warning: Could not extract docstring from {function_name}: {e}", level="warning") 589 docstring = "" 590 description = description or "No description provided." 591 592 # Extract parameter documentation from docstring 593 param_docs = _parse_param_docs(docstring) 594 595 # Handle conflicts with existing functions 596 if override: 597 self.delete_function(function_name) 598 elif any(t["function"]["name"] == function_name for t in self.tools): 599 raise ValueError(f"Function '{function_name}' is already registered. Use override=True to replace.") 600 601 # Try to get function signature for parameter info 602 try: 603 signature = inspect.signature(external_callable) 604 except (ValueError, TypeError) as e: 605 raise ValueError(f"Cannot inspect callable '{function_name}': {e}") 606 607 # Process parameters to build schema 608 properties = {} 609 required = [] 610 611 for param_name, param in signature.parameters.items(): 612 # Skip self, cls parameters for instance/class methods 613 if param_name in ("self", "cls") and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: 614 continue 615 616 # Get parameter annotation, defaulting to Any 617 annotation = param.annotation if param.annotation != inspect.Parameter.empty else Any 618 619 try: 620 # Convert Python type to JSON Schema 621 schema = _python_type_to_schema(annotation) 622 623 # Add description from docstring or create a default one 624 schema["description"] = param_docs.get(param_name, f"{param_name} parameter") 625 626 # Add to properties 627 properties[param_name] = schema 628 629 # If no default value is provided, parameter is required 630 if param.default == inspect.Parameter.empty: 631 required.append(param_name) 632 self._log(f"[TOOL] Parameter '{param_name}' is required", level="debug") 633 else: 634 self._log(f"[TOOL] Parameter '{param_name}' has default value: {param.default}", level="debug") 635 636 except Exception as e: 637 self._log(f"[TOOL] Error processing parameter {param_name} for {function_name}: {e}", level="error") 638 # Add a basic object schema as fallback 639 properties[param_name] = { 640 "type": "string", # Default to string instead of object for better compatibility 641 "description": f"{param_name} parameter (type conversion failed)" 642 } 643 644 # For parameters with no default value, mark as required even if processing failed 645 if param.default == inspect.Parameter.empty: 646 required.append(param_name) 647 self._log(f"[TOOL] Parameter '{param_name}' marked as required despite conversion failure", level="debug") 648 649 # Apply schema extensions if provided 650 if schema_extensions: 651 for param_name, extensions in schema_extensions.items(): 652 if param_name in properties: 653 properties[param_name].update(extensions) 654 655 # Create parameters object with proper placement of 'required' field 656 parameters = { 657 "type": "object", 658 "properties": properties, 659 } 660 661 # Only add required field if there are required parameters 662 if required: 663 parameters["required"] = required 664 665 # Build the final tool specification 666 tool_spec = { 667 "type": "function", 668 "function": { 669 "name": function_name, 670 "description": description, 671 "parameters": parameters 672 } 673 } 674 675 # Set disabled flag if requested 676 if disabled: 677 tool_spec["function"]["disabled"] = True 678 679 # Add to tools list 680 self.tools.append(tool_spec) 681 682 # Make the function available as an attribute on the instance 683 setattr(self, function_name, external_callable) 684 685 # Log the registration with detailed information 686 self._log(f"[TOOL] Registered function '{function_name}' with {len(properties)} parameters", level="info") 687 if required: 688 self._log(f"[TOOL] Required parameters: {required}", level="info") 689 690 return function_name # Return the name for reference
Register a function for LLM tool calling with full type hints and metadata.
Args: external_callable (Callable): The function to register. name (Optional[str]): Optional custom name. Defaults to function's __name__. description (Optional[str]): Optional custom description. Defaults to first line of docstring. override (bool): If True, replaces an existing tool with the same name. disabled (bool): If True, registers the function in a disabled state. schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to schema extensions that override or add to the auto-generated schema.
Raises: ValueError: If the callable is invalid or duplicate name found without override.
Example: interactor.add_function( my_tool, override=True, disabled=False, schema_extensions={ "param1": {"minimum": 0, "maximum": 100}, "param2": {"format": "email"} } )
693 def disable_function(self, name: str) -> bool: 694 """ 695 Disable a registered tool function by name. 696 697 This marks the function as inactive for tool calling without removing it from the internal registry. 698 The function remains visible in the tool listing but is skipped during tool selection by the LLM. 699 700 Args: 701 name (str): The name of the function to disable. 702 703 Returns: 704 bool: True if the function was found and disabled, False otherwise. 705 706 Example: 707 interactor.disable_function("extract_text") 708 """ 709 for tool in self.tools: 710 if tool["function"]["name"] == name: 711 tool["function"]["disabled"] = True 712 return True 713 return False
Disable a registered tool function by name.
This marks the function as inactive for tool calling without removing it from the internal registry. The function remains visible in the tool listing but is skipped during tool selection by the LLM.
Args: name (str): The name of the function to disable.
Returns: bool: True if the function was found and disabled, False otherwise.
Example: interactor.disable_function("extract_text")
716 def enable_function(self, name: str) -> bool: 717 """ 718 Re-enable a previously disabled tool function by name. 719 720 This removes the 'disabled' flag from a tool function, making it available again for LLM use. 721 722 Args: 723 name (str): The name of the function to enable. 724 725 Returns: 726 bool: True if the function was found and enabled, False otherwise. 727 728 Example: 729 interactor.enable_function("extract_text") 730 """ 731 for tool in self.tools: 732 if tool["function"]["name"] == name: 733 tool["function"].pop("disabled", None) 734 return True 735 return False
Re-enable a previously disabled tool function by name.
This removes the 'disabled' flag from a tool function, making it available again for LLM use.
Args: name (str): The name of the function to enable.
Returns: bool: True if the function was found and enabled, False otherwise.
Example: interactor.enable_function("extract_text")
738 def delete_function(self, name: str) -> bool: 739 """ 740 Permanently remove a registered tool function from the Interactor. 741 742 This deletes both the tool metadata and the callable attribute, making it fully inaccessible 743 from the active session. Useful for dynamically trimming the toolset. 744 745 Args: 746 name (str): The name of the function to delete. 747 748 Returns: 749 bool: True if the function was found and removed, False otherwise. 750 751 Example: 752 interactor.delete_function("extract_text") 753 """ 754 before = len(self.tools) 755 self.tools = [tool for tool in self.tools if tool["function"]["name"] != name] 756 if hasattr(self, name): 757 delattr(self, name) 758 return len(self.tools) < before
Permanently remove a registered tool function from the Interactor.
This deletes both the tool metadata and the callable attribute, making it fully inaccessible from the active session. Useful for dynamically trimming the toolset.
Args: name (str): The name of the function to delete.
Returns: bool: True if the function was found and removed, False otherwise.
Example: interactor.delete_function("extract_text")
761 def list_functions(self) -> List[Dict[str, Any]]: 762 """Get the list of registered functions for tool calling. 763 764 Returns: 765 List[Dict[str, Any]]: List of registered functions. 766 """ 767 return self.tools
Get the list of registered functions for tool calling.
Returns: List[Dict[str, Any]]: List of registered functions.
770 def list_models(self, update=False): 771 """list available models""" 772 if update: 773 self._update_available_models() 774 775 return self.models
list available models
897 def interact( 898 self, 899 user_input: Optional[str], 900 quiet: bool = False, 901 tools: bool = True, 902 stream: bool = True, 903 markdown: bool = False, 904 model: Optional[str] = None, 905 output_callback: Optional[Callable[[str], None]] = None, 906 session_id: Optional[str] = None, 907 raw: Optional[bool] = None, 908 tool_suppress: bool = True, 909 timeout: float = 60.0 910 ) -> Union[Optional[str], "TokenStream"]: 911 """Main universal gateway for all LLM interaction. 912 913 This function serves as the single entry point for all interactions with the language model. 914 When `raw=False` (default), it handles the interaction internally and returns the full response. 915 When `raw=True`, it returns a context manager that yields chunks of the response for custom handling. 916 917 Args: 918 user_input: Text input from the user. 919 quiet: If True, don't print status info or progress. 920 tools: Enable (True) or disable (False) tool calling. 921 stream: Enable (True) or disable (False) streaming responses. 922 markdown: If True, renders content as markdown. 923 model: Optional model override. 924 output_callback: Optional callback to handle the output. 925 session_id: Optional session ID to load messages from. 926 raw: If True, return a context manager instead of handling the interaction internally. 927 If None, use the class-level setting from __init__. 928 tool_suppress: If True and raw=True, filter out tool-related status messages. 929 timeout: Maximum time in seconds to wait for the stream to complete when raw=True. 930 931 Returns: 932 If raw=False: The complete response from the model as a string, or None if there was an error. 933 If raw=True: A context manager that yields chunks of the response as they arrive. 934 935 Example with default mode: 936 response = ai.interact("Tell me a joke") 937 938 Example with raw mode: 939 with ai.interact("Tell me a joke", raw=True) as stream: 940 for chunk in stream: 941 print(chunk, end="", flush=True) 942 """ 943 if not user_input: 944 return None 945 946 if quiet or self.quiet: 947 markdown = False 948 stream = False 949 950 # Determine if we should use raw mode 951 # If raw parameter is explicitly provided, use that; otherwise use class setting 952 use_raw = self.raw if raw is None else raw 953 954 # If raw mode is requested, delegate to interact_raw 955 if use_raw: 956 return self._interact_raw( 957 user_input=user_input, 958 tools=tools, 959 model=model, 960 session_id=session_id, 961 tool_suppress=tool_suppress, 962 timeout=timeout 963 ) 964 965 # Setup model if specified 966 if model: 967 self._setup_client(model) 968 self._setup_encoding() 969 970 # Session handling 971 if self.session_enabled and session_id: 972 self.session_id = session_id 973 self.session_load(session_id) 974 975 # Add user message using messages_add 976 self.messages_add(role="user", content=user_input) 977 978 # Log token count estimate 979 token_count = self._count_tokens(self.history) 980 if not quiet: 981 print(f"[dim]Estimated tokens in context: {token_count} / {self.context_length}[/dim]") 982 983 # Make sure we have enough context space 984 if token_count > self.context_length: 985 if self._cycle_messages(): 986 if not quiet: 987 print("[red]Context window exceeded. Cannot proceed.[/red]") 988 return None 989 990 # Log user input 991 self._log(f"[USER] {user_input}") 992 993 # Handle the actual interaction with complete streaming for all responses 994 result = asyncio.run(self._interact_async_core( 995 user_input=user_input, 996 quiet=quiet, 997 tools=tools, 998 stream=stream, 999 markdown=markdown, 1000 output_callback=output_callback 1001 )) 1002 1003 # Log completion for this interaction 1004 self._log(f"[INTERACTION] Completed with {len(self.history)} total messages") 1005 1006 return result
Main universal gateway for all LLM interaction.
This function serves as the single entry point for all interactions with the language model.
When raw=False
(default), it handles the interaction internally and returns the full response.
When raw=True
, it returns a context manager that yields chunks of the response for custom handling.
Args: user_input: Text input from the user. quiet: If True, don't print status info or progress. tools: Enable (True) or disable (False) tool calling. stream: Enable (True) or disable (False) streaming responses. markdown: If True, renders content as markdown. model: Optional model override. output_callback: Optional callback to handle the output. session_id: Optional session ID to load messages from. raw: If True, return a context manager instead of handling the interaction internally. If None, use the class-level setting from __init__. tool_suppress: If True and raw=True, filter out tool-related status messages. timeout: Maximum time in seconds to wait for the stream to complete when raw=True.
Returns: If raw=False: The complete response from the model as a string, or None if there was an error. If raw=True: A context manager that yields chunks of the response as they arrive.
Example with default mode: response = ai.interact("Tell me a joke")
Example with raw mode: with ai.interact("Tell me a joke", raw=True) as stream: for chunk in stream: print(chunk, end="", flush=True)
2207 def messages_add( 2208 self, 2209 role: str, 2210 content: Any, 2211 tool_info: Optional[Dict] = None, 2212 normalize: bool = True 2213 ) -> str: 2214 """ 2215 Add a message to the standardized session_history and then update SDK-specific history. 2216 2217 This method is the central point for all message additions to the conversation. 2218 2219 Args: 2220 role: The role of the message ("user", "assistant", "system", "tool") 2221 content: The message content (text or structured) 2222 tool_info: Optional tool-related metadata 2223 normalize: Whether to normalize history after adding this message 2224 2225 Returns: 2226 str: Unique ID of the added message 2227 """ 2228 # Generate a unique message ID 2229 message_id = str(uuid.uuid4()) 2230 2231 # Create the standardized message for session_history 2232 timestamp = datetime.now(timezone.utc).isoformat() 2233 2234 # Store system messages directly 2235 if role == "system": 2236 self.system = content 2237 2238 # Create standard format message 2239 standard_message = { 2240 "role": role, 2241 "content": content, 2242 "id": message_id, 2243 "timestamp": timestamp, 2244 "metadata": { 2245 "sdk": self.sdk 2246 } 2247 } 2248 2249 # Add tool info if provided 2250 if tool_info: 2251 standard_message["metadata"]["tool_info"] = tool_info 2252 2253 # Add to session_history 2254 if not hasattr(self, "session_history"): 2255 self.session_history = [] 2256 2257 self.session_history.append(standard_message) 2258 2259 # Save to persistent session if enabled 2260 if self.session_enabled and self.session_id: 2261 # Convert standard message to session-compatible format 2262 session_msg = { 2263 "role": role, 2264 "content": content, 2265 "id": message_id, 2266 "timestamp": timestamp 2267 } 2268 2269 # Add tool-related fields if present 2270 if tool_info: 2271 for key, value in tool_info.items(): 2272 session_msg[key] = value 2273 2274 # Store in session 2275 self.session.msg_insert(self.session_id, session_msg) 2276 2277 # Update the SDK-specific format in self.history by running the normalizer 2278 if normalize: 2279 # We only need to normalize the most recent message for efficiency 2280 # Pass a flag indicating we're just normalizing a new message 2281 self._normalizer(force=False, new_message_only=True) 2282 2283 # Log the added message 2284 self._log(f"[MESSAGE ADDED] {role}: {str(content)[:50]}...") 2285 2286 return message_id
Add a message to the standardized session_history and then update SDK-specific history.
This method is the central point for all message additions to the conversation.
Args: role: The role of the message ("user", "assistant", "system", "tool") content: The message content (text or structured) tool_info: Optional tool-related metadata normalize: Whether to normalize history after adding this message
Returns: str: Unique ID of the added message
2289 def messages_system(self, prompt: str): 2290 """Set or retrieve the current system prompt. 2291 2292 This method manages the system prompt that guides the AI's behavior. 2293 It can be used to both set a new system prompt and retrieve the current one. 2294 When setting a new prompt, it updates the system message in the conversation 2295 history and persists it to the session if enabled. 2296 2297 Args: 2298 prompt (str): The new system prompt to set. If empty or None, 2299 returns the current system prompt without changes. 2300 2301 Returns: 2302 str: The current system prompt after any updates 2303 """ 2304 if not isinstance(prompt, str) or not prompt: 2305 return self.system 2306 2307 # If the prompt hasn't changed, don't do anything 2308 if self.system == prompt: 2309 return self.system 2310 2311 # Update the system prompt 2312 old_system = self.system 2313 self.system = prompt 2314 2315 # For OpenAI, update or insert the system message in history 2316 if self.sdk == "openai": 2317 # Check if there's already a system message 2318 system_index = next((i for i, msg in enumerate(self.history) 2319 if msg.get("role") == "system"), None) 2320 2321 if system_index is not None: 2322 # Update existing system message 2323 self.history[system_index]["content"] = prompt 2324 else: 2325 # Insert new system message at the beginning 2326 self.history.insert(0, {"role": "system", "content": prompt}) 2327 2328 # For Anthropic, system message is not part of history, just save it for API calls 2329 2330 # Log to session only if prompt actually changed 2331 if self.session_enabled and self.session_id and old_system != prompt: 2332 self.session.msg_insert(self.session_id, {"role": "system", "content": prompt}) 2333 2334 return self.system
Set or retrieve the current system prompt.
This method manages the system prompt that guides the AI's behavior. It can be used to both set a new system prompt and retrieve the current one. When setting a new prompt, it updates the system message in the conversation history and persists it to the session if enabled.
Args: prompt (str): The new system prompt to set. If empty or None, returns the current system prompt without changes.
Returns: str: The current system prompt after any updates
2337 def messages(self) -> list: 2338 """Return full session messages (persisted or in-memory).""" 2339 if self.session_enabled and self.session_id: 2340 return self.session.load_full(self.session_id).get("messages", []) 2341 return self.session_history
Return full session messages (persisted or in-memory).
2344 def messages_length(self) -> int: 2345 """Calculate the total token count for the message history.""" 2346 if not self.encoding: 2347 return 0 2348 2349 total_tokens = 0 2350 for message in self.history: 2351 if message.get("content"): 2352 total_tokens += len(self.encoding.encode(message["content"])) 2353 if message.get("tool_calls"): 2354 for tool_call in message["tool_calls"]: 2355 if tool_call.get("function"): 2356 total_tokens += len(self.encoding.encode(tool_call["function"].get("name", ""))) 2357 total_tokens += len(self.encoding.encode(tool_call["function"].get("arguments", ""))) 2358 return total_tokens
Calculate the total token count for the message history.
2361 def session_load(self, session_id: Optional[str]): 2362 """Load and normalize messages for a specific session. 2363 2364 This method loads a conversation session from persistent storage and 2365 normalizes the messages to the current SDK format. It handles system 2366 messages, tool calls, and maintains message ordering. If loading fails, 2367 it resets to an empty session with the default system prompt. 2368 2369 Args: 2370 session_id (Optional[str]): The ID of the session to load. 2371 If None, resets to in-memory mode. 2372 2373 Note: 2374 This method will update both session_history and history to match 2375 the loaded session's state. It also ensures proper SDK-specific 2376 message formatting. 2377 """ 2378 self.session_id = session_id 2379 self._last_session_id = session_id 2380 2381 if self.session_enabled and session_id: 2382 try: 2383 # Load raw session data 2384 session_data = self.session.load_full(session_id) 2385 messages = session_data.get("messages", []) 2386 2387 # Convert session format to our standard format 2388 self.session_history = [] 2389 2390 # Track the most recent system message 2391 latest_system_msg = None 2392 2393 for msg in messages: 2394 # Extract fields 2395 role = msg.get("role", "user") 2396 content = msg.get("content", "") 2397 msg_id = msg.get("id", str(uuid.uuid4())) 2398 timestamp = msg.get("timestamp", datetime.now(timezone.utc).isoformat()) 2399 2400 # If this is a system message, track it but don't add to session_history yet 2401 if role == "system": 2402 if latest_system_msg is None or timestamp > latest_system_msg["timestamp"]: 2403 latest_system_msg = { 2404 "role": role, 2405 "content": content, 2406 "id": msg_id, 2407 "timestamp": timestamp, 2408 "metadata": {"sdk": self.sdk} 2409 } 2410 continue 2411 2412 # Build tool_info if present 2413 tool_info = None 2414 if any(key in msg for key in ["tool_use_id", "tool_call_id", "name", "arguments"]): 2415 tool_info = { 2416 "id": msg.get("tool_use_id") or msg.get("tool_call_id"), 2417 "name": msg.get("name", "unknown_tool"), 2418 "arguments": msg.get("arguments", {}) 2419 } 2420 2421 # Create standard message 2422 standard_msg = { 2423 "role": role, 2424 "content": content, 2425 "id": msg_id, 2426 "timestamp": timestamp, 2427 "metadata": { 2428 "sdk": self.sdk 2429 } 2430 } 2431 2432 if tool_info: 2433 standard_msg["metadata"]["tool_info"] = tool_info 2434 2435 self.session_history.append(standard_msg) 2436 2437 # If we found a system message, update the system property and add to history 2438 if latest_system_msg: 2439 self.system = latest_system_msg["content"] 2440 # Insert at the beginning of session_history 2441 self.session_history.insert(0, latest_system_msg) 2442 else: 2443 # If no system message was found, add the current system message 2444 self.messages_add(role="system", content=self.system) 2445 2446 # Normalize to current SDK format 2447 self._normalizer(force=True) 2448 2449 self._log(f"[SESSION] Switched to session '{session_id}'") 2450 except Exception as e: 2451 self.logger.error(f"Failed to load session '{session_id}': {e}") 2452 self.session_reset() 2453 else: 2454 # Reset to empty state with system message 2455 self.session_reset()
Load and normalize messages for a specific session.
This method loads a conversation session from persistent storage and normalizes the messages to the current SDK format. It handles system messages, tool calls, and maintains message ordering. If loading fails, it resets to an empty session with the default system prompt.
Args: session_id (Optional[str]): The ID of the session to load. If None, resets to in-memory mode.
Note: This method will update both session_history and history to match the loaded session's state. It also ensures proper SDK-specific message formatting.
2458 def session_reset(self): 2459 """Reset the current session state and reinitialize to default system prompt. 2460 2461 This method performs a complete reset of the conversation state: 2462 1. Clears all message history 2463 2. Disables session ID tracking 2464 3. Returns to in-memory mode 2465 4. Reinitializes with the default system prompt 2466 2467 The reset is useful for starting fresh conversations or recovering 2468 from error states. It maintains the basic system configuration while 2469 clearing all conversation context. 2470 """ 2471 self.session_id = None 2472 self._last_session_id = None 2473 2474 # Clear histories 2475 self.session_history = [] 2476 self.history = [] 2477 2478 # Reapply the system message 2479 if hasattr(self, "system") and self.system: 2480 # Add to session_history 2481 self.messages_add(role="system", content=self.system) 2482 else: 2483 # Ensure we have a default system message 2484 self.system = "You are a helpful Assistant." 2485 self.messages_add(role="system", content=self.system) 2486 2487 self._log("[SESSION] Reset to in-memory mode")
Reset the current session state and reinitialize to default system prompt.
This method performs a complete reset of the conversation state:
- Clears all message history
- Disables session ID tracking
- Returns to in-memory mode
- Reinitializes with the default system prompt
The reset is useful for starting fresh conversations or recovering from error states. It maintains the basic system configuration while clearing all conversation context.
2931 def track_token_usage(self): 2932 """Track and return token usage across the conversation history. 2933 2934 This method maintains a history of token usage measurements and provides 2935 current usage statistics. It tracks: 2936 - Current token count 2937 - Context length limit 2938 - Usage percentage 2939 - Historical measurements 2940 - Current provider and model info 2941 2942 Returns: 2943 dict: Dictionary containing: 2944 - current: Current token count 2945 - limit: Maximum context length 2946 - percentage: Usage as percentage of limit 2947 - history: Last 10 measurements with timestamps 2948 - provider: Current provider name 2949 - model: Current model name 2950 2951 Note: 2952 The history is limited to the last 100 measurements to prevent 2953 unbounded memory growth. 2954 """ 2955 if not hasattr(self, "_token_history"): 2956 self._token_history = [] 2957 2958 # Count current tokens 2959 current_count = self._count_tokens(self.history) 2960 2961 # Add to history 2962 timestamp = datetime.now(timezone.utc).isoformat() 2963 self._token_history.append({ 2964 "timestamp": timestamp, 2965 "count": current_count, 2966 "limit": self.context_length, 2967 "provider": self.provider, 2968 "model": self.model 2969 }) 2970 2971 # Keep only the last 100 measurements to avoid unlimited growth 2972 if len(self._token_history) > 100: 2973 self._token_history = self._token_history[-100:] 2974 2975 # Return current tracking info 2976 return { 2977 "current": current_count, 2978 "limit": self.context_length, 2979 "percentage": round((current_count / self.context_length) * 100, 1) if self.context_length else 0, 2980 "history": self._token_history[-10:], # Return last 10 measurements 2981 "provider": self.provider, 2982 "model": self.model 2983 }
Track and return token usage across the conversation history.
This method maintains a history of token usage measurements and provides current usage statistics. It tracks:
- Current token count
- Context length limit
- Usage percentage
- Historical measurements
- Current provider and model info
Returns: dict: Dictionary containing: - current: Current token count - limit: Maximum context length - percentage: Usage as percentage of limit - history: Last 10 measurements with timestamps - provider: Current provider name - model: Current model name
Note: The history is limited to the last 100 measurements to prevent unbounded memory growth.
2986 def get_message_token_breakdown(self): 2987 """Analyze token usage by message type and provide a detailed breakdown. 2988 2989 This method performs a detailed analysis of token usage across the 2990 conversation history, breaking down usage by: 2991 - Message role (system, user, assistant, tool) 2992 - Content type (text, tool calls, tool results) 2993 - Individual message statistics 2994 2995 Returns: 2996 dict: Token usage breakdown containing: 2997 - total: Total tokens used 2998 - by_role: Tokens used by each role 2999 - by_type: Tokens used by content type 3000 - messages: List of individual message stats including: 3001 - index: Message position 3002 - role: Message role 3003 - tokens: Tokens used 3004 - has_tools: Whether message contains tool calls 3005 3006 Note: 3007 This analysis is useful for understanding token usage patterns 3008 and optimizing conversation context. 3009 """ 3010 breakdown = { 3011 "total": 0, 3012 "by_role": { 3013 "system": 0, 3014 "user": 0, 3015 "assistant": 0, 3016 "tool": 0 3017 }, 3018 "by_type": { 3019 "text": 0, 3020 "tool_calls": 0, 3021 "tool_results": 0 3022 }, 3023 "messages": [] 3024 } 3025 3026 # Analyze each message 3027 for i, msg in enumerate(self.history): 3028 msg_tokens = self._count_tokens([msg]) 3029 role = msg.get("role", "unknown") 3030 3031 # Track by role 3032 if role in breakdown["by_role"]: 3033 breakdown["by_role"][role] += msg_tokens 3034 3035 # Track by content type 3036 if role == "assistant" and msg.get("tool_calls"): 3037 breakdown["by_type"]["tool_calls"] += msg_tokens 3038 elif role == "tool": 3039 breakdown["by_type"]["tool_results"] += msg_tokens 3040 else: 3041 breakdown["by_type"]["text"] += msg_tokens 3042 3043 # Add individual message data 3044 breakdown["messages"].append({ 3045 "index": i, 3046 "role": role, 3047 "tokens": msg_tokens, 3048 "has_tools": bool(msg.get("tool_calls") or msg.get("tool_use") or 3049 (isinstance(msg.get("content"), list) and 3050 any(isinstance(c, dict) and c.get("type") in ["tool_use", "tool_result"] 3051 for c in msg.get("content", [])))) 3052 }) 3053 3054 # Update tota? 3055 breakdown["total"] += msg_tokens 3056 3057 return breakdown
Analyze token usage by message type and provide a detailed breakdown.
This method performs a detailed analysis of token usage across the conversation history, breaking down usage by:
- Message role (system, user, assistant, tool)
- Content type (text, tool calls, tool results)
- Individual message statistics
Returns: dict: Token usage breakdown containing: - total: Total tokens used - by_role: Tokens used by each role - by_type: Tokens used by content type - messages: List of individual message stats including: - index: Message position - role: Message role - tokens: Tokens used - has_tools: Whether message contains tool calls
Note: This analysis is useful for understanding token usage patterns and optimizing conversation context.
19class Session: 20 def __init__(self, directory: str = None): 21 """ 22 Initialize the session manager and ensure the session directory exists. 23 24 Args: 25 directory (str): Filesystem path for session storage. Must not be None or empty. 26 27 Raises: 28 ValueError: If directory is None or not a string. 29 OSError: If the directory cannot be created or accessed. 30 """ 31 if not directory: 32 raise ValueError("Session directory must be a valid non-empty string path.") 33 34 try: 35 self.path = Path(os.path.expanduser(directory)) 36 self.path.mkdir(parents=True, exist_ok=True) 37 except Exception as e: 38 raise OSError(f"Failed to initialize session directory '{directory}': {e}") 39 40 # --------------------------- 41 # Core CRUD 42 # --------------------------- 43 44 def list(self) -> List[Dict]: 45 """ 46 Return metadata for all sessions in the directory. 47 48 Returns: 49 List[Dict]: Sorted list of session metadata dictionaries. 50 """ 51 out = [] 52 for file in self.path.glob("*.json"): 53 try: 54 with open(file, "r") as f: 55 d = json.load(f) 56 out.append({ 57 "id": d.get("id"), 58 "name": d.get("name"), 59 "created": d.get("created"), 60 "tags": d.get("tags", []), 61 "summary": d.get("summary") 62 }) 63 except Exception: 64 continue 65 return sorted(out, key=lambda x: x["created"], reverse=True) 66 67 def create(self, name: str, tags: Optional[List[str]] = None) -> str: 68 """ 69 Create and persist a new session. 70 71 Args: 72 name (str): Name of the new session. 73 tags (List[str], optional): Optional list of tags. 74 75 Returns: 76 str: Unique session ID of the new session. 77 """ 78 sid = str(uuid.uuid4()) 79 session = { 80 "id": sid, 81 "name": name, 82 "created": datetime.now(timezone.utc).isoformat(), 83 "parent": None, 84 "branch_point": None, 85 "tags": tags or [], 86 "summary": None, 87 "messages": [] 88 } 89 self._save_file(sid, session) 90 return sid 91 92 def load(self, session_id: str) -> List[Dict]: 93 """ 94 Return OpenAI-compatible message list from a session. 95 96 Filters out internal keys and leaves only standard API-compatible fields. 97 98 Args: 99 session_id (str): ID of the session to load. 100 101 Returns: 102 List[Dict]: List of clean message dictionaries. 103 """ 104 session = self._read_file(session_id) 105 return [ 106 {k: v for k, v in m.items() if k in { 107 "role", "content", "tool_calls", "name", "function_call", "tool_call_id" 108 }} for m in session.get("messages", []) 109 ] 110 111 def load_full(self, session_id: str) -> Dict: 112 """ 113 Return the complete session file as-is. 114 115 Args: 116 session_id (str): ID of the session. 117 118 Returns: 119 Dict: Entire raw session data. 120 """ 121 return self._read_file(session_id) 122 123 def delete(self, session_id: str): 124 """ 125 Delete a session file from disk. 126 127 Args: 128 session_id (str): ID of the session to delete. 129 """ 130 file = self.path / f"{session_id}.json" 131 if file.exists(): 132 file.unlink() 133 134 def update(self, session_id: str, key: str, value: Any): 135 """ 136 Update a top-level key in a session file. 137 138 Args: 139 session_id (str): Session ID. 140 key (str): Field to update. 141 value (Any): New value for the field. 142 """ 143 session = self._read_file(session_id) 144 session[key] = value 145 self._save_file(session_id, session) 146 147 # --------------------------- 148 # Message Operations 149 # --------------------------- 150 151 def msg_insert(self, session_id: str, message: Dict) -> str: 152 """ 153 Insert a new message into a session. 154 155 Args: 156 session_id (str): Session ID. 157 message (Dict): Message dictionary to insert. 158 159 Returns: 160 str: ID of the inserted message. 161 """ 162 session = self._read_file(session_id) 163 entry = { 164 "id": str(uuid.uuid4()), 165 "timestamp": datetime.now(timezone.utc).isoformat(), 166 **message 167 } 168 session["messages"].append(entry) 169 self._save_file(session_id, session) 170 return entry["id"] 171 172 def msg_get(self, session_id: str, message_id: str) -> Optional[Dict]: 173 """ 174 Retrieve a specific message from a session. 175 176 Args: 177 session_id (str): Session ID. 178 message_id (str): ID of the message to retrieve. 179 180 Returns: 181 Optional[Dict]: The message if found, else None. 182 """ 183 session = self._read_file(session_id) 184 for msg in session.get("messages", []): 185 if msg.get("id") == message_id: 186 return msg 187 return None 188 189 def msg_index(self, session_id: str, message_id: str) -> Optional[int]: 190 """ 191 Get the index of a message within a session. 192 193 Args: 194 session_id (str): Session ID. 195 message_id (str): Message ID. 196 197 Returns: 198 Optional[int]: Index if found, else None. 199 """ 200 session = self._read_file(session_id) 201 for i, msg in enumerate(session.get("messages", [])): 202 if msg.get("id") == message_id: 203 return i 204 return None 205 206 def msg_update(self, session_id: str, message_id: str, new_content: str) -> bool: 207 """ 208 Update the content of a specific message. 209 210 Args: 211 session_id (str): Session ID. 212 message_id (str): Message ID. 213 new_content (str): New content for the message. 214 215 Returns: 216 bool: True if update succeeded, False otherwise. 217 """ 218 session = self._read_file(session_id) 219 for m in session["messages"]: 220 if m.get("id") == message_id: 221 m["content"] = new_content 222 self._save_file(session_id, session) 223 return True 224 return False 225 226 def msg_delete(self, session_id: str, message_id: str) -> bool: 227 """ 228 Delete a message from a session. 229 230 Args: 231 session_id (str): Session ID. 232 message_id (str): Message ID. 233 234 Returns: 235 bool: True if deletion occurred, False otherwise. 236 """ 237 session = self._read_file(session_id) 238 before = len(session["messages"]) 239 session["messages"] = [m for m in session["messages"] if m.get("id") != message_id] 240 self._save_file(session_id, session) 241 return len(session["messages"]) < before 242 243 # --------------------------- 244 # Branching & Summarization 245 # --------------------------- 246 247 def branch(self, from_id: str, message_id: str, new_name: str) -> str: 248 """Create a new session by branching from a specific message. 249 250 This method creates a new session that branches from an existing one at a specific 251 message point. The new session inherits all messages up to and including the 252 specified message, then starts fresh from there. 253 254 Args: 255 from_id (str): ID of the source session to branch from. 256 message_id (str): ID of the message to branch at. 257 new_name (str): Name for the new branched session. 258 259 Returns: 260 str: ID of the newly created branched session. 261 262 Raises: 263 ValueError: If the source session or message ID is not found. 264 """ 265 # Get source session 266 source = self._read_file(from_id) 267 if not source: 268 raise ValueError(f"Source session '{from_id}' not found") 269 270 # Find the branch point 271 branch_index = self.msg_index(from_id, message_id) 272 if branch_index is None: 273 raise ValueError(f"Message '{message_id}' not found in session '{from_id}'") 274 275 # Create new session 276 new_id = self.create(new_name, source.get("tags", [])) 277 new_session = self._read_file(new_id) 278 279 # Copy messages up to branch point 280 new_session["messages"] = source["messages"][:branch_index + 1] 281 new_session["parent"] = from_id 282 new_session["branch_point"] = message_id 283 284 # Save and return 285 self._save_file(new_id, new_session) 286 return new_id 287 288 def summarize(self, interactor, session_id: str) -> str: 289 """Generate a summary of the session using the provided interactor. 290 291 This method uses the AI interactor to analyze the session content and generate 292 a concise summary. The summary is stored in the session metadata and returned. 293 294 Args: 295 interactor: An AI interactor instance capable of generating summaries. 296 session_id (str): ID of the session to summarize. 297 298 Returns: 299 str: The generated summary text. 300 301 Note: 302 The summary is automatically stored in the session metadata and can be 303 retrieved later using load_full(). 304 """ 305 session = self._read_file(session_id) 306 if not session: 307 return "" 308 309 # Get clean message list 310 messages = self.load(session_id) 311 if not messages: 312 return "" 313 314 # Generate summary 315 summary = interactor.interact( 316 "Summarize this conversation in 2-3 sentences:", 317 tools=False, 318 stream=False, 319 markdown=False 320 ) 321 322 # Store and return 323 session["summary"] = summary 324 self._save_file(session_id, session) 325 return summary 326 327 # --------------------------- 328 # Search Capabilities 329 # --------------------------- 330 331 def search(self, query: str, session_id: Optional[str] = None) -> List[Dict]: 332 """Search for messages containing the query text within a session or all sessions. 333 334 This method performs a case-insensitive text search across message content. 335 If a session_id is provided, only searches within that session. Otherwise, 336 searches across all sessions. 337 338 Args: 339 query (str): Text to search for. 340 session_id (Optional[str]): Optional session ID to limit search scope. 341 342 Returns: 343 List[Dict]: List of matching messages with their session context. 344 Each dict contains: 345 - session_id: ID of the containing session 346 - message: The matching message 347 - context: Surrounding messages for context 348 """ 349 results = [] 350 query = query.lower() 351 352 # Determine search scope 353 if session_id: 354 sessions = [(session_id, self._read_file(session_id))] 355 else: 356 sessions = [(f.stem, self._read_file(f.stem)) for f in self.path.glob("*.json")] 357 358 # Search each session 359 for sid, session in sessions: 360 if not session: 361 continue 362 363 messages = session.get("messages", []) 364 for i, msg in enumerate(messages): 365 content = str(msg.get("content", "")).lower() 366 if query in content: 367 # Get context (2 messages before and after) 368 start = max(0, i - 2) 369 end = min(len(messages), i + 3) 370 context = messages[start:end] 371 372 results.append({ 373 "session_id": sid, 374 "message": msg, 375 "context": context 376 }) 377 378 return results 379 380 def search_meta(self, query: str) -> List[Dict]: 381 """Search session metadata (name, tags, summary) for matching sessions. 382 383 This method performs a case-insensitive search across session metadata fields 384 including name, tags, and summary. It returns matching sessions with their 385 full metadata. 386 387 Args: 388 query (str): Text to search for in metadata. 389 390 Returns: 391 List[Dict]: List of matching session metadata dictionaries. 392 Each dict contains: 393 - id: Session ID 394 - name: Session name 395 - created: Creation timestamp 396 - tags: List of tags 397 - summary: Session summary if available 398 """ 399 results = [] 400 query = query.lower() 401 402 for file in self.path.glob("*.json"): 403 try: 404 with open(file, "r") as f: 405 session = json.load(f) 406 407 # Check metadata fields 408 name = str(session.get("name", "")).lower() 409 tags = [str(t).lower() for t in session.get("tags", [])] 410 summary = str(session.get("summary", "")).lower() 411 412 if (query in name or 413 any(query in tag for tag in tags) or 414 query in summary): 415 results.append({ 416 "id": session.get("id"), 417 "name": session.get("name"), 418 "created": session.get("created"), 419 "tags": session.get("tags", []), 420 "summary": session.get("summary") 421 }) 422 except Exception: 423 continue 424 425 return sorted(results, key=lambda x: x["created"], reverse=True) 426 427 # --------------------------- 428 # Internal I/O 429 # --------------------------- 430 431 def _read_file(self, session_id: str) -> Dict: 432 """Read and parse a session file from disk. 433 434 This internal method handles reading and parsing session files. 435 It ensures proper error handling and returns an empty session 436 structure if the file doesn't exist or is invalid. 437 438 Args: 439 session_id (str): ID of the session to read. 440 441 Returns: 442 Dict: Session data dictionary or empty session structure. 443 """ 444 file = self.path / f"{session_id}.json" 445 if not file.exists(): 446 return { 447 "id": session_id, 448 "name": "New Session", 449 "created": datetime.now(timezone.utc).isoformat(), 450 "messages": [] 451 } 452 453 try: 454 with open(file, "r") as f: 455 return json.load(f) 456 except Exception: 457 return { 458 "id": session_id, 459 "name": "New Session", 460 "created": datetime.now(timezone.utc).isoformat(), 461 "messages": [] 462 } 463 464 def _save_file(self, session_id: str, data: Dict): 465 """Write session data to disk. 466 467 This internal method handles writing session data to disk. 468 It ensures proper error handling and atomic writes. 469 470 Args: 471 session_id (str): ID of the session to save. 472 data (Dict): Session data to write. 473 474 Raises: 475 OSError: If the file cannot be written. 476 """ 477 file = self.path / f"{session_id}.json" 478 temp_file = file.with_suffix(".tmp") 479 480 try: 481 # Write to temporary file first 482 with open(temp_file, "w") as f: 483 json.dump(data, f, indent=2) 484 485 # Atomic rename 486 temp_file.replace(file) 487 except Exception as e: 488 if temp_file.exists(): 489 temp_file.unlink() 490 raise OSError(f"Failed to save session '{session_id}': {e}")
20 def __init__(self, directory: str = None): 21 """ 22 Initialize the session manager and ensure the session directory exists. 23 24 Args: 25 directory (str): Filesystem path for session storage. Must not be None or empty. 26 27 Raises: 28 ValueError: If directory is None or not a string. 29 OSError: If the directory cannot be created or accessed. 30 """ 31 if not directory: 32 raise ValueError("Session directory must be a valid non-empty string path.") 33 34 try: 35 self.path = Path(os.path.expanduser(directory)) 36 self.path.mkdir(parents=True, exist_ok=True) 37 except Exception as e: 38 raise OSError(f"Failed to initialize session directory '{directory}': {e}")
Initialize the session manager and ensure the session directory exists.
Args: directory (str): Filesystem path for session storage. Must not be None or empty.
Raises: ValueError: If directory is None or not a string. OSError: If the directory cannot be created or accessed.
44 def list(self) -> List[Dict]: 45 """ 46 Return metadata for all sessions in the directory. 47 48 Returns: 49 List[Dict]: Sorted list of session metadata dictionaries. 50 """ 51 out = [] 52 for file in self.path.glob("*.json"): 53 try: 54 with open(file, "r") as f: 55 d = json.load(f) 56 out.append({ 57 "id": d.get("id"), 58 "name": d.get("name"), 59 "created": d.get("created"), 60 "tags": d.get("tags", []), 61 "summary": d.get("summary") 62 }) 63 except Exception: 64 continue 65 return sorted(out, key=lambda x: x["created"], reverse=True)
Return metadata for all sessions in the directory.
Returns: List[Dict]: Sorted list of session metadata dictionaries.
67 def create(self, name: str, tags: Optional[List[str]] = None) -> str: 68 """ 69 Create and persist a new session. 70 71 Args: 72 name (str): Name of the new session. 73 tags (List[str], optional): Optional list of tags. 74 75 Returns: 76 str: Unique session ID of the new session. 77 """ 78 sid = str(uuid.uuid4()) 79 session = { 80 "id": sid, 81 "name": name, 82 "created": datetime.now(timezone.utc).isoformat(), 83 "parent": None, 84 "branch_point": None, 85 "tags": tags or [], 86 "summary": None, 87 "messages": [] 88 } 89 self._save_file(sid, session) 90 return sid
Create and persist a new session.
Args: name (str): Name of the new session. tags (List[str], optional): Optional list of tags.
Returns: str: Unique session ID of the new session.
92 def load(self, session_id: str) -> List[Dict]: 93 """ 94 Return OpenAI-compatible message list from a session. 95 96 Filters out internal keys and leaves only standard API-compatible fields. 97 98 Args: 99 session_id (str): ID of the session to load. 100 101 Returns: 102 List[Dict]: List of clean message dictionaries. 103 """ 104 session = self._read_file(session_id) 105 return [ 106 {k: v for k, v in m.items() if k in { 107 "role", "content", "tool_calls", "name", "function_call", "tool_call_id" 108 }} for m in session.get("messages", []) 109 ]
Return OpenAI-compatible message list from a session.
Filters out internal keys and leaves only standard API-compatible fields.
Args: session_id (str): ID of the session to load.
Returns: List[Dict]: List of clean message dictionaries.
111 def load_full(self, session_id: str) -> Dict: 112 """ 113 Return the complete session file as-is. 114 115 Args: 116 session_id (str): ID of the session. 117 118 Returns: 119 Dict: Entire raw session data. 120 """ 121 return self._read_file(session_id)
Return the complete session file as-is.
Args: session_id (str): ID of the session.
Returns: Dict: Entire raw session data.
123 def delete(self, session_id: str): 124 """ 125 Delete a session file from disk. 126 127 Args: 128 session_id (str): ID of the session to delete. 129 """ 130 file = self.path / f"{session_id}.json" 131 if file.exists(): 132 file.unlink()
Delete a session file from disk.
Args: session_id (str): ID of the session to delete.
134 def update(self, session_id: str, key: str, value: Any): 135 """ 136 Update a top-level key in a session file. 137 138 Args: 139 session_id (str): Session ID. 140 key (str): Field to update. 141 value (Any): New value for the field. 142 """ 143 session = self._read_file(session_id) 144 session[key] = value 145 self._save_file(session_id, session)
Update a top-level key in a session file.
Args: session_id (str): Session ID. key (str): Field to update. value (Any): New value for the field.
151 def msg_insert(self, session_id: str, message: Dict) -> str: 152 """ 153 Insert a new message into a session. 154 155 Args: 156 session_id (str): Session ID. 157 message (Dict): Message dictionary to insert. 158 159 Returns: 160 str: ID of the inserted message. 161 """ 162 session = self._read_file(session_id) 163 entry = { 164 "id": str(uuid.uuid4()), 165 "timestamp": datetime.now(timezone.utc).isoformat(), 166 **message 167 } 168 session["messages"].append(entry) 169 self._save_file(session_id, session) 170 return entry["id"]
Insert a new message into a session.
Args: session_id (str): Session ID. message (Dict): Message dictionary to insert.
Returns: str: ID of the inserted message.
172 def msg_get(self, session_id: str, message_id: str) -> Optional[Dict]: 173 """ 174 Retrieve a specific message from a session. 175 176 Args: 177 session_id (str): Session ID. 178 message_id (str): ID of the message to retrieve. 179 180 Returns: 181 Optional[Dict]: The message if found, else None. 182 """ 183 session = self._read_file(session_id) 184 for msg in session.get("messages", []): 185 if msg.get("id") == message_id: 186 return msg 187 return None
Retrieve a specific message from a session.
Args: session_id (str): Session ID. message_id (str): ID of the message to retrieve.
Returns: Optional[Dict]: The message if found, else None.
189 def msg_index(self, session_id: str, message_id: str) -> Optional[int]: 190 """ 191 Get the index of a message within a session. 192 193 Args: 194 session_id (str): Session ID. 195 message_id (str): Message ID. 196 197 Returns: 198 Optional[int]: Index if found, else None. 199 """ 200 session = self._read_file(session_id) 201 for i, msg in enumerate(session.get("messages", [])): 202 if msg.get("id") == message_id: 203 return i 204 return None
Get the index of a message within a session.
Args: session_id (str): Session ID. message_id (str): Message ID.
Returns: Optional[int]: Index if found, else None.
206 def msg_update(self, session_id: str, message_id: str, new_content: str) -> bool: 207 """ 208 Update the content of a specific message. 209 210 Args: 211 session_id (str): Session ID. 212 message_id (str): Message ID. 213 new_content (str): New content for the message. 214 215 Returns: 216 bool: True if update succeeded, False otherwise. 217 """ 218 session = self._read_file(session_id) 219 for m in session["messages"]: 220 if m.get("id") == message_id: 221 m["content"] = new_content 222 self._save_file(session_id, session) 223 return True 224 return False
Update the content of a specific message.
Args: session_id (str): Session ID. message_id (str): Message ID. new_content (str): New content for the message.
Returns: bool: True if update succeeded, False otherwise.
226 def msg_delete(self, session_id: str, message_id: str) -> bool: 227 """ 228 Delete a message from a session. 229 230 Args: 231 session_id (str): Session ID. 232 message_id (str): Message ID. 233 234 Returns: 235 bool: True if deletion occurred, False otherwise. 236 """ 237 session = self._read_file(session_id) 238 before = len(session["messages"]) 239 session["messages"] = [m for m in session["messages"] if m.get("id") != message_id] 240 self._save_file(session_id, session) 241 return len(session["messages"]) < before
Delete a message from a session.
Args: session_id (str): Session ID. message_id (str): Message ID.
Returns: bool: True if deletion occurred, False otherwise.
247 def branch(self, from_id: str, message_id: str, new_name: str) -> str: 248 """Create a new session by branching from a specific message. 249 250 This method creates a new session that branches from an existing one at a specific 251 message point. The new session inherits all messages up to and including the 252 specified message, then starts fresh from there. 253 254 Args: 255 from_id (str): ID of the source session to branch from. 256 message_id (str): ID of the message to branch at. 257 new_name (str): Name for the new branched session. 258 259 Returns: 260 str: ID of the newly created branched session. 261 262 Raises: 263 ValueError: If the source session or message ID is not found. 264 """ 265 # Get source session 266 source = self._read_file(from_id) 267 if not source: 268 raise ValueError(f"Source session '{from_id}' not found") 269 270 # Find the branch point 271 branch_index = self.msg_index(from_id, message_id) 272 if branch_index is None: 273 raise ValueError(f"Message '{message_id}' not found in session '{from_id}'") 274 275 # Create new session 276 new_id = self.create(new_name, source.get("tags", [])) 277 new_session = self._read_file(new_id) 278 279 # Copy messages up to branch point 280 new_session["messages"] = source["messages"][:branch_index + 1] 281 new_session["parent"] = from_id 282 new_session["branch_point"] = message_id 283 284 # Save and return 285 self._save_file(new_id, new_session) 286 return new_id
Create a new session by branching from a specific message.
This method creates a new session that branches from an existing one at a specific message point. The new session inherits all messages up to and including the specified message, then starts fresh from there.
Args: from_id (str): ID of the source session to branch from. message_id (str): ID of the message to branch at. new_name (str): Name for the new branched session.
Returns: str: ID of the newly created branched session.
Raises: ValueError: If the source session or message ID is not found.
288 def summarize(self, interactor, session_id: str) -> str: 289 """Generate a summary of the session using the provided interactor. 290 291 This method uses the AI interactor to analyze the session content and generate 292 a concise summary. The summary is stored in the session metadata and returned. 293 294 Args: 295 interactor: An AI interactor instance capable of generating summaries. 296 session_id (str): ID of the session to summarize. 297 298 Returns: 299 str: The generated summary text. 300 301 Note: 302 The summary is automatically stored in the session metadata and can be 303 retrieved later using load_full(). 304 """ 305 session = self._read_file(session_id) 306 if not session: 307 return "" 308 309 # Get clean message list 310 messages = self.load(session_id) 311 if not messages: 312 return "" 313 314 # Generate summary 315 summary = interactor.interact( 316 "Summarize this conversation in 2-3 sentences:", 317 tools=False, 318 stream=False, 319 markdown=False 320 ) 321 322 # Store and return 323 session["summary"] = summary 324 self._save_file(session_id, session) 325 return summary
Generate a summary of the session using the provided interactor.
This method uses the AI interactor to analyze the session content and generate a concise summary. The summary is stored in the session metadata and returned.
Args: interactor: An AI interactor instance capable of generating summaries. session_id (str): ID of the session to summarize.
Returns: str: The generated summary text.
Note: The summary is automatically stored in the session metadata and can be retrieved later using load_full().
331 def search(self, query: str, session_id: Optional[str] = None) -> List[Dict]: 332 """Search for messages containing the query text within a session or all sessions. 333 334 This method performs a case-insensitive text search across message content. 335 If a session_id is provided, only searches within that session. Otherwise, 336 searches across all sessions. 337 338 Args: 339 query (str): Text to search for. 340 session_id (Optional[str]): Optional session ID to limit search scope. 341 342 Returns: 343 List[Dict]: List of matching messages with their session context. 344 Each dict contains: 345 - session_id: ID of the containing session 346 - message: The matching message 347 - context: Surrounding messages for context 348 """ 349 results = [] 350 query = query.lower() 351 352 # Determine search scope 353 if session_id: 354 sessions = [(session_id, self._read_file(session_id))] 355 else: 356 sessions = [(f.stem, self._read_file(f.stem)) for f in self.path.glob("*.json")] 357 358 # Search each session 359 for sid, session in sessions: 360 if not session: 361 continue 362 363 messages = session.get("messages", []) 364 for i, msg in enumerate(messages): 365 content = str(msg.get("content", "")).lower() 366 if query in content: 367 # Get context (2 messages before and after) 368 start = max(0, i - 2) 369 end = min(len(messages), i + 3) 370 context = messages[start:end] 371 372 results.append({ 373 "session_id": sid, 374 "message": msg, 375 "context": context 376 }) 377 378 return results
Search for messages containing the query text within a session or all sessions.
This method performs a case-insensitive text search across message content. If a session_id is provided, only searches within that session. Otherwise, searches across all sessions.
Args: query (str): Text to search for. session_id (Optional[str]): Optional session ID to limit search scope.
Returns: List[Dict]: List of matching messages with their session context. Each dict contains: - session_id: ID of the containing session - message: The matching message - context: Surrounding messages for context
380 def search_meta(self, query: str) -> List[Dict]: 381 """Search session metadata (name, tags, summary) for matching sessions. 382 383 This method performs a case-insensitive search across session metadata fields 384 including name, tags, and summary. It returns matching sessions with their 385 full metadata. 386 387 Args: 388 query (str): Text to search for in metadata. 389 390 Returns: 391 List[Dict]: List of matching session metadata dictionaries. 392 Each dict contains: 393 - id: Session ID 394 - name: Session name 395 - created: Creation timestamp 396 - tags: List of tags 397 - summary: Session summary if available 398 """ 399 results = [] 400 query = query.lower() 401 402 for file in self.path.glob("*.json"): 403 try: 404 with open(file, "r") as f: 405 session = json.load(f) 406 407 # Check metadata fields 408 name = str(session.get("name", "")).lower() 409 tags = [str(t).lower() for t in session.get("tags", [])] 410 summary = str(session.get("summary", "")).lower() 411 412 if (query in name or 413 any(query in tag for tag in tags) or 414 query in summary): 415 results.append({ 416 "id": session.get("id"), 417 "name": session.get("name"), 418 "created": session.get("created"), 419 "tags": session.get("tags", []), 420 "summary": session.get("summary") 421 }) 422 except Exception: 423 continue 424 425 return sorted(results, key=lambda x: x["created"], reverse=True)
Search session metadata (name, tags, summary) for matching sessions.
This method performs a case-insensitive search across session metadata fields including name, tags, and summary. It returns matching sessions with their full metadata.
Args: query (str): Text to search for in metadata.
Returns: List[Dict]: List of matching session metadata dictionaries. Each dict contains: - id: Session ID - name: Session name - created: Creation timestamp - tags: List of tags - summary: Session summary if available