interactor
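# The excerpt below begins at the class definition. The import block that
# follows is an assumed reconstruction (not from the original file) of the
# names the class body references; the Session import in particular is a
# guess at a local module.

import asyncio
import inspect
import json
import logging
import os
import queue
import re
import sys
import threading
import traceback
import uuid
from datetime import date, datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args, get_origin

import aiohttp
import anthropic
import openai
import tiktoken
from openai import APIConnectionError, OpenAIError, RateLimitError
from rich import print  # assumed: markup strings like "[yellow]...[/yellow]" are passed to print()
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from rich.prompt import Confirm

from .session import Session  # assumed local module providing session persistence

console = Console()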
class Interactor:
    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: str = "openai:gpt-4o-mini",
        fallback_model: str = "ollama:mistral-nemo:latest",
        tools: Optional[bool] = True,
        stream: bool = True,
        quiet: bool = False,
        context_length: int = 128000,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        log_path: Optional[str] = None,
        raw: Optional[bool] = False,
        session_enabled: bool = False,
        session_id: Optional[str] = None,
        session_path: Optional[str] = None
    ):
        """Initialize the universal AI interaction client.

        Args:
            base_url: Optional base URL for the API. If None, uses the provider's default URL.
            api_key: Optional API key. If None, attempts to use environment variables based on provider.
            model: Model identifier in format "provider:model_name" (e.g., "openai:gpt-4o-mini").
            fallback_model: Model to switch to when the primary model's retries are exhausted.
            tools: Enable (True) or disable (False) tool calling; None for auto-detection based on model support.
            stream: Enable (True) or disable (False) streaming responses.
            quiet: If True, suppress console status output.
            context_length: Maximum number of tokens to maintain in conversation history.
            max_retries: Maximum number of retries for failed API calls.
            retry_delay: Initial delay (in seconds) for exponential backoff retries.
            log_path: Optional path to a file for DEBUG-level logging.
            raw: Class-level default for raw streaming mode; see interact().
            session_enabled: Enable persistent session support.
            session_id: Optional session ID to load messages from.
            session_path: Optional directory for persisted session files.

        Raises:
            ValueError: If provider is not supported or API key is missing for non-Ollama providers.
        """
        self.system = "You are a helpful Assistant."
        self.raw = raw
        self.quiet = quiet
        self.logger = logging.getLogger(f"InteractorLogger_{id(self)}")
        self.logger.setLevel(logging.DEBUG)
        self.providers = {
            "openai": {
                "sdk": "openai",
                "base_url": "https://api.openai.com/v1",
                "api_key": api_key or os.getenv("OPENAI_API_KEY") or None
            },
            "ollama": {
                "sdk": "openai",
                "base_url": "http://localhost:11434/v1",
                "api_key": api_key or "ollama"
            },
            "nvidia": {
                "sdk": "openai",
                "base_url": "https://integrate.api.nvidia.com/v1",
                "api_key": api_key or os.getenv("NVIDIA_API_KEY") or None
            },
            "google": {
                "sdk": "openai",
                "base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
                "api_key": api_key or os.getenv("GEMINI_API_KEY") or None
            },
            "anthropic": {
                "sdk": "anthropic",
                "base_url": "https://api.anthropic.com/v1",
                "api_key": api_key or os.getenv("ANTHROPIC_API_KEY") or None
            },
            "mistral": {
                "sdk": "openai",
                "base_url": "https://api.mistral.ai/v1",
                "api_key": api_key or os.getenv("MISTRAL_API_KEY") or None
            },
            "deepseek": {
                "sdk": "openai",
                "base_url": "https://api.deepseek.com",
                "api_key": api_key or os.getenv("DEEPSEEK_API_KEY") or None
            },
            # "grok": {
            #     "sdk": "grok",
            #     "base_url": "https://api.x.ai/v1",
            #     "api_key": api_key or os.getenv("GROK_API_KEY") or None
            # },
        }

        # Console log handler (always enabled at WARNING+)
        if not self.logger.handlers:
            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setLevel(logging.WARNING)
            console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
            self.logger.addHandler(console_handler)

        self._log_enabled = False
        if log_path:
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(logging.Formatter(
                "%(asctime)s - %(levelname)s - %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S"
            ))
            self.logger.addHandler(file_handler)
            self._log_enabled = True

        self.token_estimate = 0
        self.last_token_estimate = 0
        self.stream = stream
        self.tools = []
        self.session_history = []
        self.history = []
        self.context_length = context_length
        self.encoding = None
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.reveal_tool = []
        self.fallback_model = fallback_model
        self.sdk = None

        # Session support
        self.session_enabled = session_enabled
        self.session_id = session_id
        self._last_session_id = session_id
        self.session = Session(directory=session_path) if session_enabled else None

        if model is None:
            model = "openai:gpt-4o-mini"

        # Initialize model + encoding
        self._setup_client(model, base_url, api_key)
        self.tools_enabled = self.tools_supported if tools is None else tools and self.tools_supported
        self._setup_encoding()
        self.messages_add(role="system", content=self.system)

        # Populate list of available models
        self.models = self._update_available_models()
    def _log(self, message: str, level: str = "info"):
        """Log a message to the configured logging handlers.

        This internal method handles logging to both console and file handlers
        if configured. It respects the logging level and only logs if logging
        is enabled.

        Args:
            message (str): The message to log
            level (str): Logging level - one of "debug", "info", "warning", "error"
        """
        if self._log_enabled:
            getattr(self.logger, level)(message)


    def _setup_client(
        self,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None
    ):
        """Initialize or reconfigure the Interactor for the given model and SDK.

        Ensures idempotent setup, assigns SDK-specific clients and tool-handling logic,
        and normalizes history to match the provider-specific message schema.
        """
        if not model:
            raise ValueError("Model must be specified as 'provider:model_name'")

        provider, model_name = model.split(":", 1)

        if not hasattr(self, "session_history"):
            self.session_history = []

        # Skip setup if nothing has changed (client may not yet exist on first call)
        if (
            hasattr(self, "client")
            and self.client
            and self.provider == provider
            and self.model == model_name
            and self.base_url == (base_url or self.base_url)
        ):
            return

        if provider not in self.providers:
            raise ValueError(f"Unsupported provider: {provider}. Supported: {list(self.providers.keys())}")

        # Load provider configuration
        provider_config = self.providers[provider]
        self.sdk = provider_config.get("sdk", "openai")
        self.provider = provider
        self.model = model_name
        self.base_url = base_url or provider_config["base_url"]
        effective_api_key = api_key or provider_config["api_key"]

        if not effective_api_key and provider != "ollama":
            raise ValueError(f"API key not provided and not found in environment for {provider.upper()}_API_KEY")

        # SDK-specific configuration
        if self.sdk == "openai":
            self.client = openai.OpenAI(base_url=self.base_url, api_key=effective_api_key)
            self.async_client = openai.AsyncOpenAI(base_url=self.base_url, api_key=effective_api_key)
            self.sdk_runner = self._openai_runner
            self.tool_key = "tool_call_id"

        elif self.sdk == "anthropic":
            self.client = anthropic.Anthropic(api_key=effective_api_key)
            self.async_client = anthropic.AsyncAnthropic(api_key=effective_api_key)
            self.sdk_runner = self._anthropic_runner
            self.tool_key = "tool_use_id"

        else:
            raise ValueError(f"Unsupported SDK type: {self.sdk}")

        # Determine tool support
        self.tools_supported = self._check_tool_support()
        if not self.tools_supported:
            self.logger.warning(f"Tool calling not supported for {provider}:{model_name}")

        # Normalize session history to match SDK after any provider/model change
        self._normalizer(force=True)

        self._log(f"[MODEL] Switched to {provider}:{model_name}")


    def _check_tool_support(self) -> bool:
        """Determine if the current model supports tool calling.

        Returns:
            bool: True if tools are supported for the active provider/model, False otherwise.
        """
        try:
            if self.sdk == "openai":
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": "Test tool support."}],
                    stream=False,
                    tools=[{
                        "type": "function",
                        "function": {
                            "name": "test_tool",
                            "description": "Check tool support",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "query": {"type": "string"}
                                },
                                "required": ["query"]
                            }
                        }
                    }],
                    tool_choice="auto"
                )
                message = response.choices[0].message
                return bool(message.tool_calls and len(message.tool_calls) > 0)

            elif self.sdk == "anthropic":
                # For Claude models, support is pre-defined based on model ID.
                # Known tool-supporting Claude models:
                claude_tool_models = ["claude-3-opus", "claude-3-sonnet", "claude-3-haiku",
                                      "claude-3.5-sonnet", "claude-3.7-sonnet"]

                # Check if the current model supports tools
                for supported_model in claude_tool_models:
                    if supported_model in self.model.lower():
                        self._log(f"[TOOLS] Anthropic model {self.model} is known to support tools")
                        return True

                # If not explicitly supported, probe with a test request
                try:
                    _ = self.client.messages.create(
                        model=self.model,
                        messages=[{"role": "user", "content": "What's the weather?"}],
                        tools=[{
                            "name": "test_tool",
                            "description": "Check tool support",
                            "input_schema": {
                                "type": "object",
                                "properties": {
                                    "query": {"type": "string"}
                                },
                                "required": ["query"]
                            }
                        }],
                        max_tokens=10
                    )
                    return True
                except anthropic.BadRequestError as e:
                    error_msg = str(e).lower()
                    if "tool" in error_msg and "not supported" in error_msg:
                        self._log(f"[TOOLS] Anthropic model {self.model} does not support tools: {e}")
                        return False
                    if "not a supported tool field" in error_msg:
                        self._log(f"[TOOLS] Anthropic API rejected tool format: {e}")
                        return False
                    raise
                except Exception as e:
                    self._log(f"[TOOLS] Unexpected error testing tool support: {e}", level="error")
                    return False

            else:
                self.logger.warning(f"Tool support check not implemented for SDK '{self.sdk}'")
                return False

        except Exception as e:
            self.logger.error(f"Tool support check failed for {self.provider}:{self.model} — {e}")
            return False
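    # Model strings split on the first colon only, so model names may
    # themselves contain colons (illustrative, assumed values):
    #
    #   "ollama:mistral-nemo:latest".split(":", 1)
    #   # -> ["ollama", "mistral-nemo:latest"]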
    def add_function(
        self,
        external_callable: Callable,
        name: Optional[str] = None,
        description: Optional[str] = None,
        override: bool = False,
        disabled: bool = False,
        schema_extensions: Optional[Dict[str, Any]] = None
    ):
        """
        Register a function for LLM tool calling with full type hints and metadata.

        Args:
            external_callable (Callable): The function to register.
            name (Optional[str]): Optional custom name. Defaults to the function's __name__.
            description (Optional[str]): Optional custom description. Defaults to the first line of the docstring.
            override (bool): If True, replaces an existing tool with the same name.
            disabled (bool): If True, registers the function in a disabled state.
            schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to
                schema extensions that override or add to the auto-generated schema.

        Raises:
            ValueError: If the callable is invalid or a duplicate name is found without override.

        Example:
            interactor.add_function(
                my_tool,
                override=True,
                disabled=False,
                schema_extensions={
                    "param1": {"minimum": 0, "maximum": 100},
                    "param2": {"format": "email"}
                }
            )
        """
        def _python_type_to_schema(ptype: Any) -> dict:
            """Convert a Python type annotation to OpenAI-compatible JSON Schema."""
            # Handle None case
            if ptype is None:
                return {"type": "null"}

            # Get the origin and arguments of the type
            origin = get_origin(ptype)
            args = get_args(ptype)

            # Handle Union types (including Optional)
            if origin is Union:
                # Check for Optional (Union with None)
                none_type = type(None)
                if none_type in args:
                    non_none = [a for a in args if a is not none_type]
                    if len(non_none) == 1:
                        inner = _python_type_to_schema(non_none[0])
                        inner_copy = inner.copy()
                        inner_copy["nullable"] = True
                        return inner_copy
                    # Multiple types excluding None
                    types = [_python_type_to_schema(a) for a in non_none]
                    return {"anyOf": types, "nullable": True}
                # Regular Union without None
                return {"anyOf": [_python_type_to_schema(a) for a in args]}

            # Handle List and similar container types
            if origin in (list, List):
                item_type = args[0] if args else Any
                if item_type is Any:
                    return {"type": "array"}
                return {"type": "array", "items": _python_type_to_schema(item_type)}

            # Handle Dict types with typing info
            if origin in (dict, Dict):
                if not args or len(args) != 2:
                    return {"type": "object"}

                key_type, val_type = args
                # Only val_type is usable in JSON Schema
                if val_type is not Any and val_type is not object:
                    return {
                        "type": "object",
                        "additionalProperties": _python_type_to_schema(val_type)
                    }
                return {"type": "object"}

            # Handle Literal types for enums
            if origin is Literal:
                values = args
                # Try to determine type from values
                if all(isinstance(v, str) for v in values):
                    return {"type": "string", "enum": list(values)}
                elif all(isinstance(v, bool) for v in values):
                    return {"type": "boolean", "enum": list(values)}
                elif all(isinstance(v, (int, float)) for v in values):
                    return {"type": "number", "enum": list(values)}
                else:
                    # Mixed types, use anyOf
                    return {"anyOf": [{"type": _get_json_type(v), "enum": [v]} for v in values]}

            # Handle basic types
            if ptype is str:
                return {"type": "string"}
            if ptype is int:
                return {"type": "integer"}
            if ptype is float:
                return {"type": "number"}
            if ptype is bool:
                return {"type": "boolean"}

            # Handle common datetime types
            if ptype is datetime:
                return {"type": "string", "format": "date-time"}
            if ptype is date:
                return {"type": "string", "format": "date"}

            # Handle UUID
            if ptype is uuid.UUID:
                return {"type": "string", "format": "uuid"}

            # Default to object for any other types
            return {"type": "object"}

        def _get_json_type(value):
            """Get the JSON Schema type name for a Python value.

            Maps Python values to their corresponding JSON Schema type names,
            handling basic types and defaulting to 'object' for anything complex.

            Args:
                value: The Python value to get the JSON type for

            Returns:
                str: The JSON Schema type name ('string', 'boolean', 'number',
                'array', or 'object' as default)
            """
            if isinstance(value, str):
                return "string"
            elif isinstance(value, bool):
                return "boolean"
            elif isinstance(value, (int, float)):
                return "number"
            elif isinstance(value, list):
                return "array"
            elif isinstance(value, dict):
                return "object"
            else:
                return "object"  # Default

        def _parse_param_docs(docstring: str) -> dict:
            """Extract parameter descriptions from a docstring."""
            if not docstring:
                return {}

            lines = docstring.splitlines()
            param_docs = {}
            current_param = None
            in_params = False

            # Regular expressions for finding parameter sections and param lines
            param_section_re = re.compile(r"^(Args|Parameters):\s*$")
            param_line_re = re.compile(r"^\s{4}(\w+)\s*(?:\([^\)]*\))?:\s*(.*)")

            for line in lines:
                # Check if we're entering the parameters section
                if param_section_re.match(line.strip()):
                    in_params = True
                    continue

                if in_params:
                    # Skip empty lines
                    if not line.strip():
                        continue

                    # Check for a parameter definition line
                    match = param_line_re.match(line)
                    if match:
                        current_param = match.group(1)
                        param_docs[current_param] = match.group(2).strip()
                    # Check for continuation of a parameter description
                    elif current_param and line.startswith(" " * 8):
                        param_docs[current_param] += " " + line.strip()
                    # A line that matches neither pattern ends the params section
                    else:
                        current_param = None

            return param_docs
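        # Illustrative mappings produced by _python_type_to_schema for common
        # annotations (derived from the branches above):
        #
        #   Optional[int]        -> {"type": "integer", "nullable": True}
        #   List[str]            -> {"type": "array", "items": {"type": "string"}}
        #   Literal["a", "b"]    -> {"type": "string", "enum": ["a", "b"]}
        #   datetime             -> {"type": "string", "format": "date-time"}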
        # Start of main function logic

        # Skip if tools are disabled
        if not self.tools_enabled:
            return

        # Validate input callable
        if not external_callable:
            raise ValueError("A valid external callable must be provided.")

        # Set function name, either from parameter or from callable's __name__
        function_name = name or external_callable.__name__

        # Try to get docstring and extract description
        try:
            docstring = inspect.getdoc(external_callable)
            description = description or (docstring.split("\n")[0].strip() if docstring else "No description provided.")
        except Exception as e:
            self._log(f"[TOOL] Warning: Could not extract docstring from {function_name}: {e}", level="warning")
            docstring = ""
            description = description or "No description provided."

        # Extract parameter documentation from docstring
        param_docs = _parse_param_docs(docstring)

        # Handle conflicts with existing functions
        if override:
            self.delete_function(function_name)
        elif any(t["function"]["name"] == function_name for t in self.tools):
            raise ValueError(f"Function '{function_name}' is already registered. Use override=True to replace.")

        # Try to get function signature for parameter info
        try:
            signature = inspect.signature(external_callable)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Cannot inspect callable '{function_name}': {e}")

        # Process parameters to build schema
        properties = {}
        required = []

        for param_name, param in signature.parameters.items():
            # Skip self, cls parameters for instance/class methods
            if param_name in ("self", "cls") and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
                continue

            # Get parameter annotation, defaulting to Any
            annotation = param.annotation if param.annotation != inspect.Parameter.empty else Any

            try:
                # Convert Python type to JSON Schema
                schema = _python_type_to_schema(annotation)

                # Add description from docstring or create a default one
                schema["description"] = param_docs.get(param_name, f"{param_name} parameter")

                # Add to properties
                properties[param_name] = schema

                # If no default value is provided, parameter is required
                if param.default == inspect.Parameter.empty:
                    required.append(param_name)
                    self._log(f"[TOOL] Parameter '{param_name}' is required", level="debug")
                else:
                    self._log(f"[TOOL] Parameter '{param_name}' has default value: {param.default}", level="debug")

            except Exception as e:
                self._log(f"[TOOL] Error processing parameter {param_name} for {function_name}: {e}", level="error")
                # Add a basic string schema as fallback
                properties[param_name] = {
                    "type": "string",  # Default to string instead of object for better compatibility
                    "description": f"{param_name} parameter (type conversion failed)"
                }

                # For parameters with no default value, mark as required even if processing failed
                if param.default == inspect.Parameter.empty:
                    required.append(param_name)
                    self._log(f"[TOOL] Parameter '{param_name}' marked as required despite conversion failure", level="debug")

        # Apply schema extensions if provided
        if schema_extensions:
            for param_name, extensions in schema_extensions.items():
                if param_name in properties:
                    properties[param_name].update(extensions)

        # Create parameters object with proper placement of 'required' field
        parameters = {
            "type": "object",
            "properties": properties,
        }

        # Only add required field if there are required parameters
        if required:
            parameters["required"] = required

        # Build the final tool specification
        tool_spec = {
            "type": "function",
            "function": {
                "name": function_name,
                "description": description,
                "parameters": parameters
            }
        }

        # Set disabled flag if requested
        if disabled:
            tool_spec["function"]["disabled"] = True

        # Add to tools list
        self.tools.append(tool_spec)

        # Make the function available as an attribute on the instance
        setattr(self, function_name, external_callable)

        # Log the registration with detailed information
        self._log(f"[TOOL] Registered function '{function_name}' with {len(properties)} parameters", level="info")
        if required:
            self._log(f"[TOOL] Required parameters: {required}", level="info")

        return function_name  # Return the name for reference
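    # Registration sketch (illustrative; `get_weather` is a hypothetical tool):
    #
    #   def get_weather(city: str, units: Literal["metric", "imperial"] = "metric") -> dict:
    #       """Fetch current weather.
    #
    #       Args:
    #           city: City name to look up.
    #           units: Unit system for the reading.
    #       """
    #       ...
    #
    #   ai.add_function(get_weather)
    #   # -> tool schema with `city` required and `units` as a string enum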
    def disable_function(self, name: str) -> bool:
        """
        Disable a registered tool function by name.

        This marks the function as inactive for tool calling without removing it from the internal registry.
        The function remains visible in the tool listing but is skipped during tool selection by the LLM.

        Args:
            name (str): The name of the function to disable.

        Returns:
            bool: True if the function was found and disabled, False otherwise.

        Example:
            interactor.disable_function("extract_text")
        """
        for tool in self.tools:
            if tool["function"]["name"] == name:
                tool["function"]["disabled"] = True
                return True
        return False


    def enable_function(self, name: str) -> bool:
        """
        Re-enable a previously disabled tool function by name.

        This removes the 'disabled' flag from a tool function, making it available again for LLM use.

        Args:
            name (str): The name of the function to enable.

        Returns:
            bool: True if the function was found and enabled, False otherwise.

        Example:
            interactor.enable_function("extract_text")
        """
        for tool in self.tools:
            if tool["function"]["name"] == name:
                tool["function"].pop("disabled", None)
                return True
        return False


    def delete_function(self, name: str) -> bool:
        """
        Permanently remove a registered tool function from the Interactor.

        This deletes both the tool metadata and the callable attribute, making it fully inaccessible
        from the active session. Useful for dynamically trimming the toolset.

        Args:
            name (str): The name of the function to delete.

        Returns:
            bool: True if the function was found and removed, False otherwise.

        Example:
            interactor.delete_function("extract_text")
        """
        before = len(self.tools)
        self.tools = [tool for tool in self.tools if tool["function"]["name"] != name]
        if hasattr(self, name):
            delattr(self, name)
        return len(self.tools) < before


    def list_functions(self) -> List[Dict[str, Any]]:
        """Get the list of registered functions for tool calling.

        Returns:
            List[Dict[str, Any]]: List of registered functions.
        """
        return self.tools


    def list_models(self):
        """List available models."""
        return self.models


    def _update_available_models(
        self,
        providers: Optional[Union[str, List[str]]] = None,
        filter: Optional[str] = None
    ) -> List[str]:
        """Retrieve available models from configured providers.

        Args:
            providers: Provider name or list of provider names. If None, all are queried.
            filter: Optional regex to filter model names.

        Returns:
            List[str]: Sorted list of "provider:model_id" strings.
        """
        models = []

        if providers is None:
            providers_to_list = self.providers
        elif isinstance(providers, str):
            providers_to_list = {providers: self.providers.get(providers)}
        elif isinstance(providers, list):
            providers_to_list = {p: self.providers.get(p) for p in providers}
        else:
            return []

        invalid_providers = [p for p in providers_to_list if p not in self.providers or self.providers[p] is None]
        if invalid_providers:
            self.logger.error(f"Invalid providers: {invalid_providers}")
            return []

        regex_pattern = None
        if filter:
            try:
                regex_pattern = re.compile(filter, re.IGNORECASE)
            except re.error as e:
                self.logger.error(f"Invalid regex pattern: {e}")
                return []

        for provider_name, config in providers_to_list.items():
            sdk = config.get("sdk", "openai")
            base_url = config.get("base_url")
            api_key = config.get("api_key")

            try:
                if sdk == "openai":
                    client = openai.OpenAI(api_key=api_key, base_url=base_url)
                    response = client.models.list()
                    for model in response.data:
                        model_id = f"{provider_name}:{model.id}"
                        if not regex_pattern or regex_pattern.search(model_id):
                            models.append(model_id)

                elif sdk == "anthropic":
                    client = anthropic.Anthropic(api_key=api_key)
                    response = client.models.list()
                    for model in response:
                        model_id = f"{provider_name}:{model.id}"
                        if not regex_pattern or regex_pattern.search(model_id):
                            models.append(model_id)
                else:
                    self.logger.warning(f"SDK '{sdk}' for provider '{provider_name}' is not supported by list_models()")

            except Exception as e:
                self.logger.error(f"Failed to list models for {provider_name}: {e}")

        return sorted(models, key=str.lower)
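    # Listing sketch (illustrative): restrict the query to one provider and
    # filter model IDs by regex. The returned names are examples only.
    #
    #   ai._update_available_models(providers="openai", filter=r"gpt-4o")
    #   # -> ["openai:gpt-4o", "openai:gpt-4o-mini", ...]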
776 """ 777 models = [] 778 779 if providers is None: 780 providers_to_list = self.providers 781 elif isinstance(providers, str): 782 providers_to_list = {providers: self.providers.get(providers)} 783 elif isinstance(providers, list): 784 providers_to_list = {p: self.providers.get(p) for p in providers} 785 else: 786 return [] 787 788 invalid_providers = [p for p in providers_to_list if p not in self.providers or self.providers[p] is None] 789 if invalid_providers: 790 self.logger.error(f"Invalid providers: {invalid_providers}") 791 return [] 792 793 regex_pattern = None 794 if filter: 795 try: 796 regex_pattern = re.compile(filter, re.IGNORECASE) 797 except re.error as e: 798 self.logger.error(f"Invalid regex pattern: {e}") 799 return [] 800 801 for provider_name, config in providers_to_list.items(): 802 sdk = config.get("sdk", "openai") 803 base_url = config.get("base_url") 804 api_key = config.get("api_key") 805 806 try: 807 if sdk == "openai": 808 client = openai.OpenAI(api_key=api_key, base_url=base_url) 809 response = client.models.list() 810 for model in response.data: 811 model_id = f"{provider_name}:{model.id}" 812 if not regex_pattern or regex_pattern.search(model_id): 813 models.append(model_id) 814 815 elif sdk == "anthropic": 816 client = Anthropic(api_key=api_key) 817 response = client.models.list() 818 for model in response: 819 model_id = f"{provider_name}:{model.id}" 820 if not regex_pattern or regex_pattern.search(model_id): 821 models.append(model_id) 822 else: 823 self.logger.warning(f"SDK '{sdk}' for provider '{provider_name}' is not supported by list_models()") 824 825 except Exception as e: 826 self.logger.error(f"Failed to list models for {provider_name}: {e}") 827 828 return sorted(models, key=str.lower) 829 830 831 async def _retry_with_backoff(self, func: Callable, *args, **kwargs): 832 """Execute a function with exponential backoff retry logic. 833 834 This method implements a robust retry mechanism for API calls with 835 exponential backoff. It handles rate limits, connection errors, and 836 other transient failures. If all retries fail, it will attempt to 837 switch to a fallback model if configured. 838 839 Args: 840 func (Callable): The async function to execute 841 *args: Positional arguments to pass to the function 842 **kwargs: Keyword arguments to pass to the function 843 844 Returns: 845 The result of the function call if successful 846 847 Raises: 848 Exception: If all retries fail and no fallback model is available 849 """ 850 for attempt in range(self.max_retries + 1): 851 try: 852 return await func(*args, **kwargs) 853 854 except (RateLimitError, APIConnectionError, aiohttp.ClientError) as e: 855 if attempt == self.max_retries: 856 model_key = f"{self.provider}:{self.model}" 857 if self.fallback_model and model_key != self.fallback_model: 858 print(f"[yellow]Model '{model_key}' failed. 
    def interact(
        self,
        user_input: Optional[str],
        quiet: bool = False,
        tools: bool = True,
        stream: bool = True,
        markdown: bool = False,
        model: Optional[str] = None,
        output_callback: Optional[Callable[[str], None]] = None,
        session_id: Optional[str] = None,
        raw: Optional[bool] = None,
        tool_suppress: bool = True,
        timeout: float = 60.0
    ) -> Union[Optional[str], "TokenStream"]:
        """Main universal gateway for all LLM interaction.

        This method serves as the single entry point for all interactions with the language model.
        When `raw=False` (default), it handles the interaction internally and returns the full response.
        When `raw=True`, it returns a context manager that yields chunks of the response for custom handling.

        Args:
            user_input: Text input from the user.
            quiet: If True, don't print status info or progress.
            tools: Enable (True) or disable (False) tool calling.
            stream: Enable (True) or disable (False) streaming responses.
            markdown: If True, renders content as markdown.
            model: Optional model override.
            output_callback: Optional callback to handle the output.
            session_id: Optional session ID to load messages from.
            raw: If True, return a context manager instead of handling the interaction internally.
                If None, use the class-level setting from __init__.
            tool_suppress: If True and raw=True, filter out tool-related status messages.
            timeout: Maximum time in seconds to wait for the stream to complete when raw=True.

        Returns:
            If raw=False: The complete response from the model as a string, or None if there was an error.
            If raw=True: A context manager that yields chunks of the response as they arrive.

        Example with default mode:
            response = ai.interact("Tell me a joke")

        Example with raw mode:
            with ai.interact("Tell me a joke", raw=True) as stream:
                for chunk in stream:
                    print(chunk, end="", flush=True)
        """
        if not user_input:
            return None

        if quiet or self.quiet:
            markdown = False
            stream = False

        # Determine if we should use raw mode.
        # If the raw parameter is explicitly provided, use that; otherwise use the class setting.
        use_raw = self.raw if raw is None else raw

        # If raw mode is requested, delegate to _interact_raw
        if use_raw:
            return self._interact_raw(
                user_input=user_input,
                tools=tools,
                model=model,
                session_id=session_id,
                tool_suppress=tool_suppress,
                timeout=timeout
            )

        # Setup model if specified
        if model:
            self._setup_client(model)
            self._setup_encoding()

        # Session handling
        if self.session_enabled and session_id:
            self.session_id = session_id
            self.session_load(session_id)

        # Add user message using messages_add
        self.messages_add(role="user", content=user_input)

        # Log token count estimate
        token_count = self._count_tokens(self.history)
        if not quiet:
            print(f"[dim]Estimated tokens in context: {token_count} / {self.context_length}[/dim]")

        # Make sure we have enough context space
        if token_count > self.context_length:
            if self._cycle_messages():
                if not quiet:
                    print("[red]Context window exceeded. Cannot proceed.[/red]")
                return None

        # Log user input
        self._log(f"[USER] {user_input}")

        # Handle the actual interaction with complete streaming for all responses
        result = asyncio.run(self._interact_async_core(
            user_input=user_input,
            quiet=quiet,
            tools=tools,
            stream=stream,
            markdown=markdown,
            output_callback=output_callback
        ))

        # Log completion for this interaction
        self._log(f"[INTERACTION] Completed with {len(self.history)} total messages")

        return result
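    # Callback sketch (illustrative): collect streamed chunks instead of
    # printing them. Note that tool-status JSON notifications are also
    # delivered to the callback when tools fire (see _handle_tool_call_async).
    #
    #   chunks = []
    #   ai.interact("Explain DNS briefly.", output_callback=chunks.append)
    #   full_text = "".join(chunks)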
    def _interact_raw(
        self,
        user_input: Optional[str],
        tools: bool = True,
        model: Optional[str] = None,
        session_id: Optional[str] = None,
        tool_suppress: bool = True,
        timeout: float = 60.0
    ):
        """
        Low-level method that returns a raw stream of tokens from the model.

        This method works as a context manager that yields a generator of streaming tokens.
        The caller is responsible for handling the output stream. Typically, this is used
        indirectly through interact() with raw=True.

        Args:
            user_input: Text input from the user.
            tools: Enable (True) or disable (False) tool calling.
            model: Optional model override.
            session_id: Optional session ID to load messages from.
            tool_suppress: If True, filter out tool-related status messages.
            timeout: Maximum time in seconds to wait for the stream to complete.

        Returns:
            A context manager that yields a stream of tokens.

        Example:
            with ai.interact("Hello world", raw=True) as stream:
                for chunk in stream:
                    print(chunk, end="", flush=True)
        """
        if not user_input:
            return None

        # Setup model if specified
        if model:
            self._setup_client(model)
            self._setup_encoding()

        # Session handling
        if self.session_enabled and session_id:
            self.session_id = session_id
            self.session_load(session_id)

        # Add user message
        self.messages_add(role="user", content=user_input)

        # Log token count estimate
        token_count = self._count_tokens(self.history)
        self._log(f"[STREAM] Estimated tokens in context: {token_count} / {self.context_length}")

        # Make sure we have enough context space
        if token_count > self.context_length:
            if self._cycle_messages():
                self._log("[STREAM] Context window exceeded. Cannot proceed.", level="error")
                return None

        # Log user input
        self._log(f"[USER] {user_input}")

        # Create a token stream class using a thread-safe queue
        class TokenStream:
            def __init__(self, interactor, user_input, tools, tool_suppress, timeout):
                """Initialize a new TokenStream instance.

                This class provides a context manager for streaming token responses
                from the AI model. It handles asynchronous token delivery, tool call
                suppression, and timeout management.

                Args:
                    interactor: The parent Interactor instance
                    user_input: The user's input text
                    tools: Whether tool calling is enabled
                    tool_suppress: Whether to suppress tool-related status messages
                    timeout: Maximum time in seconds to wait for stream completion
                """
                self.interactor = interactor
                self.user_input = user_input
                self.tools = tools
                self.tool_suppress = tool_suppress
                self.timeout = timeout
                self.token_queue = queue.Queue()
                self.thread = None
                self.result = None
                self.error = None
                self.completed = False

            def __enter__(self):
                """Enter the context manager and start the streaming process.

                Initializes the streaming worker thread and returns self for
                iteration. The worker thread handles the actual API
                communication and token delivery.

                Returns:
                    TokenStream: Self for iteration
                """
                # Start the thread for async interaction
                def stream_worker():
                    """Worker that handles the streaming interaction.

                    Runs in a separate thread to handle the asynchronous API
                    communication and token delivery.
                    """
                    # Define output callback to put tokens in queue
                    def callback(text):
                        """Process and queue incoming text tokens.

                        Handles incoming text chunks, optionally filtering
                        tool-related messages, and adds them to the token queue.

                        Args:
                            text: The text chunk to process and queue
                        """
                        # Filter out tool messages if requested
                        if self.tool_suppress:
                            try:
                                # Check if this is a tool status message (JSON format)
                                data = json.loads(text)
                                if isinstance(data, dict) and data.get("type") == "tool_call":
                                    # Skip this message
                                    return
                            except (json.JSONDecodeError, TypeError):
                                # Not JSON or not a dict, continue normally
                                pass

                        # Add to queue
                        self.token_queue.put(text)

                    # Run the interaction in a new event loop
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)

                    try:
                        # Run the interaction
                        self.result = loop.run_until_complete(
                            self.interactor._interact_async_core(
                                user_input=self.user_input,
                                quiet=True,
                                tools=self.tools,
                                stream=True,
                                markdown=False,
                                output_callback=callback
                            )
                        )
                        # Signal successful completion
                        self.completed = True
                    except Exception as e:
                        self.error = str(e)
                        self.interactor.logger.error(f"Streaming error: {traceback.format_exc()}")
                        # Add error information to the queue if we haven't yielded anything yet
                        if self.token_queue.empty():
                            self.token_queue.put(f"Error: {str(e)}")
                    finally:
                        # Signal end of stream regardless of success/failure
                        self.token_queue.put(None)
                        loop.close()

                # Start the worker thread
                self.thread = threading.Thread(target=stream_worker)
                self.thread.daemon = True
                self.thread.start()

                # Return self for iteration
                return self

            def __iter__(self):
                """Return self as an iterator.

                Returns:
                    TokenStream: Self for iteration
                """
                return self

            def __next__(self):
                """Get the next token from the stream.

                Implements the iterator protocol, retrieving the next token
                from the queue with timeout handling.

                Returns:
                    str: The next token from the stream

                Raises:
                    StopIteration: When the stream is complete or times out
                """
                # Get next token from queue with timeout to prevent hanging
                try:
                    token = self.token_queue.get(timeout=self.timeout)
                    if token is None:
                        # End of stream
                        raise StopIteration
                    return token
                except queue.Empty:
                    # Timeout reached
                    self.interactor.logger.warning(f"Stream timeout after {self.timeout}s")
                    if not self.completed and not self.error:
                        # The worker thread might be stuck
                        if self.thread and self.thread.is_alive():
                            # We can't forcibly terminate a thread in Python,
                            # but we can report the issue
                            self.interactor.logger.error("Stream worker thread is hung")
                    raise StopIteration
            def __exit__(self, exc_type, exc_val, exc_tb):
                """Exit the context manager and clean up resources.

                Handles cleanup when the context manager is exited, including
                thread cleanup and message history updates.

                Args:
                    exc_type: The exception type if an exception was raised
                    exc_val: The exception value if an exception was raised
                    exc_tb: The exception traceback if an exception was raised

                Returns:
                    bool: False to not suppress any exceptions
                """
                # Clean up resources
                if self.thread and self.thread.is_alive():
                    self.thread.join(timeout=2.0)

                # Add messages to history if successful
                if self.completed and self.result and not exc_type:
                    if isinstance(self.result, str) and self.result != "No response.":
                        # If we had a successful completion, ensure the result is in the history
                        last_msg = self.interactor.history[-1] if self.interactor.history else None
                        if not last_msg or last_msg.get("role") != "assistant" or last_msg.get("content") != self.result:
                            # Add a clean assistant message to history if not already there
                            self.interactor.messages_add(role="assistant", content=self.result)

                # If there was an error in the stream processing, log it
                if self.error:
                    self.interactor.logger.error(f"Stream processing error: {self.error}")

                return False  # Don't suppress exceptions

        return TokenStream(self, user_input, tools, tool_suppress, timeout)
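    # TokenStream follows the classic sentinel-terminated producer/consumer
    # pattern: the worker thread puts text chunks on a queue.Queue and a final
    # None to mark end-of-stream; __next__ turns that None into StopIteration.
    # Minimal standalone sketch of the same pattern (illustrative):
    #
    #   q = queue.Queue()
    #   threading.Thread(target=lambda: [q.put(c) for c in ("a", "b", None)]).start()
    #   while (chunk := q.get(timeout=5.0)) is not None:
    #       print(chunk, end="")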
    async def _interact_async_core(
        self,
        user_input: str,
        quiet: bool = False,
        tools: bool = True,
        stream: bool = True,
        markdown: bool = False,
        output_callback: Optional[Callable] = None
    ) -> str:
        """Main SDK-agnostic async execution pipeline with tool call looping support."""
        # Prepare display handler
        live = Live(console=console, refresh_per_second=100) if markdown and stream else None
        if live:
            live.start()

        # Initialize variables for iteration tracking
        full_content = ""
        max_iterations = 5  # Prevent infinite loops
        iterations = 0

        # Main interaction loop - continues until no more tool calls or max iterations reached
        while iterations < max_iterations:
            iterations += 1

            try:
                # Execute the appropriate SDK runner - history is already normalized
                response_data = await self.sdk_runner(
                    model=self.model,
                    messages=self.history,
                    stream=stream,
                    markdown=markdown,
                    quiet=quiet if iterations == 1 else False,
                    live=live,
                    output_callback=output_callback
                )

                # Extract response data
                content = response_data.get("content", "")
                tool_calls = response_data.get("tool_calls", [])

                # Log the response data for debugging
                self._log(f"[ITERATION {iterations}] Content: {len(content)} chars, Tool calls: {len(tool_calls)}")

                # Add content to full response
                if iterations == 1:
                    full_content = content
                elif content:
                    if full_content and content:
                        full_content += f"\n{content}"
                    else:
                        full_content = content

                # Add assistant message with or without tool calls
                if tool_calls:
                    # Process each tool call
                    for call in tool_calls:
                        # Add assistant message with tool call
                        tool_info = {
                            "id": call["id"],
                            "name": call["function"]["name"],
                            "arguments": call["function"]["arguments"]
                        }

                        # Add the assistant message with tool call
                        self.messages_add(
                            role="assistant",
                            content=content if len(tool_calls) == 1 else "",
                            tool_info=tool_info
                        )

                        # Execute the tool
                        call_name = call["function"]["name"]
                        call_args = call["function"]["arguments"]
                        call_id = call["id"]

                        # Stop Rich Live while executing tool calls
                        live_was_active = False
                        if live and live.is_started:
                            live_was_active = True
                            live.stop()

                        result = await self._handle_tool_call_async(
                            function_name=call_name,
                            function_arguments=call_args,
                            tool_call_id=call_id,
                            quiet=quiet,
                            safe=False,
                            output_callback=output_callback
                        )

                        # Restart live display if it was active before
                        if live_was_active and live:
                            live.start()

                        # Add tool result message
                        tool_result_info = {
                            "id": call_id,
                            "result": result
                        }

                        self.messages_add(
                            role="tool",
                            content=result,
                            tool_info=tool_result_info
                        )
                else:
                    # Simple assistant response without tool calls
                    self.messages_add(role="assistant", content=content)
                    break  # No more tools to process, we're done

                # Reset live display if needed
                if stream and live:
                    live.stop()
                    live = Live(console=console, refresh_per_second=100)
                    live.start()

            except Exception as e:
                self.logger.error(f"[{self.sdk.upper()} ERROR] {str(e)}")
                self._log(f"[ERROR] Error in interaction loop: {str(e)}", level="error")
                if live:
                    live.stop()
                return f"Error: {str(e)}"

        # Clean up display
        if live:
            live.stop()

        return full_content or None
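    # Message flow sketch for one tool round-trip through the loop above
    # (illustrative; `calc` is a hypothetical tool):
    #
    #   user       -> "What's 2+2?"
    #   assistant  -> tool_call {name: "calc", arguments: '{"expr": "2+2"}'}
    #   tool       -> '{"result": 4}'
    #   assistant  -> "2 + 2 = 4"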
    async def _openai_runner(
        self,
        *,
        model,
        messages,
        stream,
        markdown=False,
        quiet=False,
        live=None,
        output_callback=None
    ):
        """Handle OpenAI-specific API interactions and response processing."""
        # Log what we're sending for debugging
        self._log(f"[OPENAI REQUEST] Sending request to {model} with {len(self.history)} messages", level="debug")

        # Prepare API parameters - history is already normalized by _normalizer
        params = {
            "model": model,
            "messages": self.history,
            "stream": stream,
        }

        # Add tools if enabled
        if self.tools_enabled and self.tools_supported:
            enabled_tools = self._get_enabled_tools()
            if enabled_tools:
                params["tools"] = enabled_tools
                params["tool_choice"] = "auto"

        # Call API with retry handling
        try:
            response = await self._retry_with_backoff(
                self.async_client.chat.completions.create,
                **params
            )
        except Exception:
            self.logger.error(f"[OPENAI ERROR RUNNER]: {traceback.format_exc()}")
            raise

        assistant_content = ""
        tool_calls_dict = {}

        # Process streaming response
        if stream and hasattr(response, "__aiter__"):
            async for chunk in response:
                delta = getattr(chunk.choices[0], "delta", None)

                # Handle content chunks
                if hasattr(delta, "content") and delta.content is not None:
                    text = delta.content
                    assistant_content += text
                    if output_callback:
                        output_callback(text)
                    elif live:
                        live.update(Markdown(assistant_content))
                    elif not markdown:
                        print(text, end="")

                # Process tool calls
                if hasattr(delta, "tool_calls") and delta.tool_calls:
                    for tool_call_delta in delta.tool_calls:
                        index = tool_call_delta.index
                        if index not in tool_calls_dict:
                            tool_calls_dict[index] = {
                                "id": getattr(tool_call_delta, "id", None),
                                "function": {"name": "", "arguments": ""}
                            }

                        function = getattr(tool_call_delta, "function", None)
                        if function:
                            name = getattr(function, "name", None)
                            args = getattr(function, "arguments", "")
                            if name:
                                tool_calls_dict[index]["function"]["name"] = name
                            if args:
                                tool_calls_dict[index]["function"]["arguments"] += args

                        # Make sure the ID is set once it arrives in a later delta
                        if getattr(tool_call_delta, "id", None) and not tool_calls_dict[index]["id"]:
                            tool_calls_dict[index]["id"] = tool_call_delta.id

            if not output_callback and not markdown and not quiet:
                print()

        # Process non-streaming response
        else:
            message = response.choices[0].message
            assistant_content = message.content or ""

            if hasattr(message, "tool_calls") and message.tool_calls:
                for i, tool_call in enumerate(message.tool_calls):
                    tool_calls_dict[i] = {
                        "id": tool_call.id,
                        "function": {
                            "name": tool_call.function.name,
                            "arguments": tool_call.function.arguments
                        }
                    }

            if output_callback:
                output_callback(assistant_content)
            elif not quiet:
                print(assistant_content)

        # Log tool calls for debugging
        if tool_calls_dict:
            self._log(f"[OPENAI TOOL CALLS] Found {len(tool_calls_dict)} tool calls", level="debug")
            for idx, call in tool_calls_dict.items():
                self._log(f"[OPENAI TOOL CALL {idx}] {call['function']['name']} with ID {call['id']}", level="debug")

        # Return standardized response format
        return {
            "content": assistant_content,
            "tool_calls": list(tool_calls_dict.values())
        }
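    # Both runners return the same shape, which _interact_async_core consumes
    # (values here are illustrative):
    #
    #   {
    #       "content": "Checking the weather now.",
    #       "tool_calls": [
    #           {"id": "call_abc123",
    #            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
    #       ]
    #   }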
self._log(f"[ANTHROPIC CHUNK] Type: {chunk_type}", level="debug") 1550 if chunk_type == "content_block_start" and hasattr(chunk.content_block, "type"): 1551 content_type = chunk.content_block.type 1552 if content_type == "tool_use": 1553 tool_id = chunk.content_block.id 1554 tool_name = chunk.content_block.name 1555 tool_input = chunk.content_block.input 1556 tool_calls_dict[tool_id] = { 1557 "id": tool_id, 1558 "function": { 1559 "name": tool_name, 1560 "arguments": "" 1561 } 1562 } 1563 self._log(f"[ANTHROPIC TOOL USE] {tool_name}", level="debug") 1564 1565 # Handle text content 1566 if chunk_type == "content_block_delta" and hasattr(chunk.delta, "text"): 1567 delta = chunk.delta.text 1568 assistant_content += delta 1569 if output_callback: 1570 output_callback(delta) 1571 elif live: 1572 live.update(Markdown(assistant_content)) 1573 elif not markdown: 1574 print(delta, end="") 1575 1576 # Handle complete tool use 1577 elif chunk_type == "content_block_delta" and content_type == "tool_use": 1578 tool_calls_dict[tool_id]["function"]["arguments"] += chunk.delta.partial_json 1579 1580 # Process non-streaming response 1581 else: 1582 # For non-streaming, ensure we don't send the stream parameter 1583 non_stream_params = params.copy() 1584 non_stream_params.pop("stream", None) # Remove stream if it exists 1585 1586 response = await self._retry_with_backoff( 1587 self.async_client.messages.create, 1588 **non_stream_params 1589 ) 1590 1591 # Extract text content 1592 for content_block in response.content: 1593 if content_block.type == "text": 1594 assistant_content += content_block.text 1595 1596 if content_block.type == "tool_use": 1597 tool_id = content_block.id 1598 tool_name = content_block.name 1599 tool_input = content_block.input 1600 tool_calls_dict[tool_id] = { 1601 "id": tool_id, 1602 "function": { 1603 "name": tool_name, 1604 "arguments": tool_input 1605 } 1606 } 1607 self._log(f"[ANTHROPIC TOOL USE] {tool_name}", level="debug") 1608 1609 if output_callback: 1610 output_callback(assistant_content) 1611 elif not quiet: 1612 print(assistant_content) 1613 1614 except Exception as e: 1615 self._log(f"[ANTHROPIC ERROR RUNNER] {traceback.format_exc()}", level="error") 1616 1617 # Return something usable even in case of error 1618 return { 1619 "content": f"Error processing Anthropic response: {str(e)}", 1620 "tool_calls": [] 1621 } 1622 1623 # Return standardized response format 1624 return { 1625 "content": assistant_content, 1626 "tool_calls": list(tool_calls_dict.values()) 1627 } 1628 1629 1630 def _get_enabled_tools(self) -> List[dict]: 1631 """Return the list of currently enabled tool function definitions.""" 1632 return [ 1633 tool for tool in self.tools 1634 if not tool["function"].get("disabled", False) 1635 ] 1636 1637 1638 async def _handle_tool_call_async( 1639 self, 1640 function_name: str, 1641 function_arguments: str, 1642 tool_call_id: str, 1643 quiet: bool = False, 1644 safe: bool = False, 1645 output_callback: Optional[Callable[[str], None]] = None 1646 ) -> str: 1647 """Process a tool call asynchronously and return the result. 1648 1649 Args: 1650 function_name: Name of the function to call. 1651 function_arguments: JSON string containing the function arguments. 1652 tool_call_id: Unique identifier for this tool call. 1653 params: Parameters used for the original API call. 1654 safe: If True, prompts for confirmation before executing the tool call. 1655 output_callback: Optional callback to handle the tool call result. 
    async def _handle_tool_call_async(
        self,
        function_name: str,
        function_arguments: str,
        tool_call_id: str,
        quiet: bool = False,
        safe: bool = False,
        output_callback: Optional[Callable[[str], None]] = None
    ) -> str:
        """Process a tool call asynchronously and return the result.

        Args:
            function_name: Name of the function to call.
            function_arguments: JSON string (or dict) containing the function arguments.
            tool_call_id: Unique identifier for this tool call.
            quiet: If True, suppress status output; falls back to the instance-level setting.
            safe: If True, prompts for confirmation before executing the tool call.
            output_callback: Optional callback to handle the tool call result.

        Returns:
            The result of the function call.

        Raises:
            ValueError: If the function is not found or JSON is invalid.
        """
        if isinstance(function_arguments, str):
            arguments = json.loads(function_arguments)
        else:
            arguments = function_arguments

        self._log(f"[TOOL:{function_name}] args={arguments}")

        func = getattr(self, function_name, None)
        if not func:
            raise ValueError(f"Function '{function_name}' not found.")

        be_quiet = self.quiet if quiet is None else quiet

        if not be_quiet:
            print(f"\nRunning {function_name}...")

        if output_callback:
            notification = json.dumps({
                "type": "tool_call",
                "tool_name": function_name,
                "status": "started"
            })
            output_callback(notification)

        try:
            if safe:
                prompt = f"[bold yellow]Proposed tool call:[/bold yellow] {function_name}({json.dumps(arguments, indent=2)})\n[bold cyan]Execute? [y/n]: [/bold cyan]"
                confirmed = Confirm.ask(prompt, default=False)
                if not confirmed:
                    command_result = {
                        "status": "cancelled",
                        "message": "Tool call aborted by user"
                    }
                    print("[red]Tool call cancelled by user[/red]")
                else:
                    loop = asyncio.get_event_loop()
                    command_result = await loop.run_in_executor(None, lambda: func(**arguments))
            else:
                loop = asyncio.get_event_loop()
                command_result = await loop.run_in_executor(None, lambda: func(**arguments))

            try:
                json.dumps(command_result)
            except TypeError as e:
                self.logger.error(f"Tool call result not serializable: {e}")
                return {"error": "Tool call returned unserializable data."}

            if output_callback:
                notification = json.dumps({
                    "type": "tool_call",
                    "tool_name": function_name,
                    "status": "completed"
                })
                output_callback(notification)

            return command_result

        except Exception as e:
            self._log(f"[ERROR] Tool execution failed: {e}", level="error")
            self.logger.error(f"Error executing tool function '{function_name}': {e}")
            return {"error": str(e)}
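    # Notification sketch: when an output_callback is supplied, tool execution
    # brackets the call with JSON status messages of exactly this shape
    # (`get_weather` is an illustrative tool name):
    #
    #   {"type": "tool_call", "tool_name": "get_weather", "status": "started"}
    #   {"type": "tool_call", "tool_name": "get_weather", "status": "completed"}
    #
    # Raw-mode streams drop these when tool_suppress=True.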
1735 """ 1736 try: 1737 if self.provider == "openai": 1738 try: 1739 self.encoding = tiktoken.encoding_for_model(self.model) 1740 self._log(f"[ENCODING] Loaded tokenizer for OpenAI model: {self.model}") 1741 except: 1742 self.encoding = tiktoken.get_encoding("cl100k_base") 1743 self._log(f"[ENCODING] Fallback to cl100k_base for model: {self.model}") 1744 else: 1745 self.encoding = tiktoken.get_encoding("cl100k_base") 1746 self._log(f"[ENCODING] Defaulting to cl100k_base for non-OpenAI model: {self.model}") 1747 except Exception as e: 1748 self.logger.error(f"Failed to setup encoding: {e}") 1749 self.encoding = tiktoken.get_encoding("cl100k_base") 1750 1751 1752 def _estimate_tokens_tiktoken(self, messages) -> int: 1753 """Rough token count estimate using tiktoken for OpenAI or fallback cases.""" 1754 if not hasattr(self, "encoding") or not self.encoding: 1755 self._setup_encoding() 1756 return sum(len(self.encoding.encode(msg.get("content", ""))) for msg in messages if isinstance(msg.get("content"), str)) 1757 1758 1759 def _count_tokens(self, messages, use_cache=True) -> int: 1760 """Accurately estimate token count for messages including tool calls with caching support. 1761 1762 Args: 1763 messages: List of message objects in either OpenAI or Anthropic format. 1764 use_cache: Whether to use and update the token count cache. 1765 1766 Returns: 1767 int: Estimated token count. 1768 """ 1769 # Setup encoding if needed 1770 if not hasattr(self, "encoding") or not self.encoding: 1771 self._setup_encoding() 1772 1773 # Initialize cache if it doesn't exist 1774 if not hasattr(self, "_token_count_cache"): 1775 self._token_count_cache = {} 1776 1777 # Generate a cache key based on message content hashes 1778 if use_cache: 1779 try: 1780 # Create a cache key using message IDs if available, or content hashes 1781 cache_key_parts = [] 1782 for msg in messages: 1783 if isinstance(msg, dict): 1784 # Try to use stable identifiers for cache key 1785 msg_id = msg.get("id", None) 1786 timestamp = msg.get("timestamp", None) 1787 1788 if msg_id and timestamp: 1789 cache_key_parts.append(f"{msg_id}:{timestamp}") 1790 else: 1791 # Fall back to content-based hash if no stable IDs 1792 content_str = str(msg.get("content", "")) 1793 role = msg.get("role", "unknown") 1794 cache_key_parts.append(f"{role}:{hash(content_str)}") 1795 1796 cache_key = ":".join(cache_key_parts) 1797 if cache_key in self._token_count_cache: 1798 return self._token_count_cache[cache_key] 1799 except Exception as e: 1800 # If caching fails, just continue with normal counting 1801 self._log(f"[TOKEN COUNT] Cache key generation failed: {e}", level="debug") 1802 use_cache = False 1803 1804 # For Claude models, try to use their built-in token counter 1805 if self.sdk == "anthropic": 1806 try: 1807 # Convert messages to Anthropic format if needed 1808 anthropic_messages = [] 1809 for msg in messages: 1810 if msg.get("role") == "system": 1811 continue # System handled separately 1812 1813 if msg.get("role") == "tool": 1814 # Skip tool messages in token count to avoid double-counting 1815 continue 1816 1817 if msg.get("role") == "user" and isinstance(msg.get("content"), list): 1818 # Already in Anthropic format with tool_result 1819 anthropic_messages.append(msg) 1820 elif msg.get("role") in ["user", "assistant"]: 1821 if not msg.get("tool_calls") and not msg.get("tool_use"): 1822 # Simple message 1823 anthropic_messages.append({ 1824 "role": msg.get("role"), 1825 "content": msg.get("content", "") 1826 }) 1827 1828 # Use Anthropic's token 
counter if messages exist 1829 if anthropic_messages: 1830 response = self.client.messages.count_tokens( 1831 model=self.model, 1832 messages=anthropic_messages, 1833 system=self.system 1834 ) 1835 token_count = response.input_tokens 1836 1837 # Cache the result for future use 1838 if use_cache and 'cache_key' in locals(): 1839 self._token_count_cache[cache_key] = token_count 1840 1841 return token_count 1842 except Exception as e: 1843 # Fall back to our estimation 1844 self._log(f"[TOKEN COUNT] Error using Anthropic token counter: {e}", level="debug") 1845 1846 # More accurate token counting for all message types 1847 num_tokens = 0 1848 1849 # Count tokens for each message 1850 for msg in messages: 1851 # Base token count for message metadata (role + message format) 1852 num_tokens += 4 # Message overhead 1853 1854 # Add tokens for role name 1855 role = msg.get("role", "") 1856 num_tokens += len(self.encoding.encode(role)) 1857 1858 # Count tokens in message content 1859 if isinstance(msg.get("content"), str): 1860 content = msg.get("content", "") 1861 content_tokens = len(self.encoding.encode(content)) 1862 num_tokens += content_tokens 1863 1864 elif isinstance(msg.get("content"), list): 1865 # Handle Anthropic-style content lists 1866 for item in msg.get("content", []): 1867 if isinstance(item, dict): 1868 # Tool result or other structured content 1869 if item.get("type") == "tool_result": 1870 result_content = item.get("content", "") 1871 if isinstance(result_content, str): 1872 num_tokens += len(self.encoding.encode(result_content)) 1873 else: 1874 # JSON serialization for dict/list content 1875 num_tokens += len(self.encoding.encode(json.dumps(result_content))) 1876 # Add tokens for tool_use_id and type fields 1877 num_tokens += len(self.encoding.encode(item.get("type", ""))) 1878 num_tokens += len(self.encoding.encode(item.get("tool_use_id", ""))) 1879 1880 # Text content type 1881 elif item.get("type") == "text": 1882 num_tokens += len(self.encoding.encode(item.get("text", ""))) 1883 1884 # Tool use type 1885 elif item.get("type") == "tool_use": 1886 num_tokens += len(self.encoding.encode(item.get("name", ""))) 1887 tool_input = item.get("input", {}) 1888 if isinstance(tool_input, str): 1889 num_tokens += len(self.encoding.encode(tool_input)) 1890 else: 1891 num_tokens += len(self.encoding.encode(json.dumps(tool_input))) 1892 num_tokens += len(self.encoding.encode(item.get("id", ""))) 1893 else: 1894 # Plain text content 1895 num_tokens += len(self.encoding.encode(str(item))) 1896 1897 # Count tokens in tool calls for OpenAI format 1898 if msg.get("tool_calls"): 1899 for tool_call in msg.get("tool_calls", []): 1900 if isinstance(tool_call, dict): 1901 # Count tokens for function name 1902 func_name = tool_call.get("function", {}).get("name", "") 1903 num_tokens += len(self.encoding.encode(func_name)) 1904 1905 # Count tokens for arguments 1906 args = tool_call.get("function", {}).get("arguments", "") 1907 if isinstance(args, str): 1908 num_tokens += len(self.encoding.encode(args)) 1909 else: 1910 num_tokens += len(self.encoding.encode(json.dumps(args))) 1911 1912 # Add tokens for id and type fields 1913 num_tokens += len(self.encoding.encode(tool_call.get("id", ""))) 1914 num_tokens += len(self.encoding.encode(tool_call.get("type", "function"))) 1915 1916 # Count tokens in Anthropic tool_use field 1917 if msg.get("tool_use"): 1918 tool_use = msg.get("tool_use") 1919 # Count tokens for name 1920 num_tokens += len(self.encoding.encode(tool_use.get("name", ""))) 1921 1922 # Count 
tokens for input 1923 tool_input = tool_use.get("input", {}) 1924 if isinstance(tool_input, str): 1925 num_tokens += len(self.encoding.encode(tool_input)) 1926 else: 1927 num_tokens += len(self.encoding.encode(json.dumps(tool_input))) 1928 1929 # Add tokens for id field 1930 num_tokens += len(self.encoding.encode(tool_use.get("id", ""))) 1931 1932 # Handle tool response message format 1933 if msg.get("role") == "tool": 1934 # Add tokens for tool_call_id 1935 tool_id = msg.get("tool_call_id", "") 1936 num_tokens += len(self.encoding.encode(tool_id)) 1937 1938 # Add message end tokens 1939 num_tokens += 2 1940 1941 # Cache the result for future use 1942 if use_cache and 'cache_key' in locals(): 1943 self._token_count_cache[cache_key] = num_tokens 1944 1945 return num_tokens 1946 1947 1948 def _cycle_messages(self): 1949 """Intelligently trim the message history to fit within the allowed context length. 1950 1951 This method implements a sophisticated trimming strategy that: 1952 1. Always preserves system messages 1953 2. Always keeps the most recent complete conversation turn 1954 3. Prioritizes keeping tool call chains intact 1955 4. Preserves important context from earlier exchanges 1956 5. Aggressively prunes redundant information before essential content 1957 1958 The method uses a multi-pass approach: 1959 1. First pass: Identify critical messages that must be kept 1960 2. Second pass: Calculate token counts for each message 1961 3. Third pass: Keep important tool chains intact 1962 4. Fourth pass: Fill remaining space with recent messages 1963 1964 Returns: 1965 bool: True if all messages were trimmed (context exceeded), 1966 False if trimming was successful and context is within limits 1967 1968 Note: 1969 This method maintains both session_history and history in sync, 1970 ensuring proper SDK-specific formatting is preserved. 
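        Example (standalone sketch of the fourth pass's greedy fill, with
        made-up token counts; not tied to instance state):

            message_tokens = [(0, 900), (1, 400), (2, 300), (3, 250)]  # (index, tokens)
            keep, budget = set(), 600
            for i, tokens in sorted(message_tokens, reverse=True):  # newest first
                if tokens <= budget:
                    keep.add(i)
                    budget -= tokens
            print(sorted(keep))  # [2, 3] -- the newest messages that fit the budget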
1971 """ 1972 # Check if we need to trim 1973 token_count = self._count_tokens(self.history) 1974 1975 # If we're already under the limit, return early 1976 if token_count <= self.context_length: 1977 return False 1978 1979 self._log(f"[TRIM] Starting message cycling: {token_count} tokens exceeds {self.context_length} limit", level="info") 1980 1981 # We'll need to track tokens as we reconstruct the history 1982 target_tokens = max(self.context_length * 0.8, self.context_length - 1000) # Target 80% or 1000 less than max 1983 1984 # First pass: identify critical messages we must keep 1985 must_keep = [] 1986 tool_chain_groups = {} # Group related tool calls and their results 1987 1988 # Always keep system messages (should be first) 1989 system_indices = [] 1990 for i, msg in enumerate(self.history): 1991 if msg.get("role") == "system": 1992 system_indices.append(i) 1993 must_keep.append(i) 1994 1995 # Identify the most recent complete exchange (user question + assistant response) 1996 latest_exchange = [] 1997 # Start from the end and work backward to find the last complete exchange 1998 for i in range(len(self.history) - 1, -1, -1): 1999 msg = self.history[i] 2000 if msg.get("role") == "assistant" and not latest_exchange: 2001 latest_exchange.append(i) 2002 elif msg.get("role") == "user" and latest_exchange: 2003 latest_exchange.append(i) 2004 break 2005 2006 # Add the latest exchange to must-keep 2007 must_keep.extend(latest_exchange) 2008 2009 # Identify tool chains - track which messages belong to the same tool flow 2010 tool_id_to_chain = {} 2011 for i, msg in enumerate(self.history): 2012 # For assistant messages with tool calls 2013 if msg.get("role") == "assistant" and msg.get("tool_calls"): 2014 for tool_call in msg.get("tool_calls"): 2015 tool_id = tool_call.get("id") 2016 if tool_id: 2017 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2018 2019 # For tool response messages 2020 elif msg.get("role") == "tool" and msg.get("tool_call_id"): 2021 tool_id = msg.get("tool_call_id") 2022 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2023 2024 # For Anthropic format with tool use 2025 elif msg.get("role") == "assistant" and isinstance(msg.get("content"), list): 2026 for block in msg.get("content", []): 2027 if isinstance(block, dict) and block.get("type") == "tool_use": 2028 tool_id = block.get("id") 2029 if tool_id: 2030 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2031 2032 # For Anthropic tool result messages 2033 elif msg.get("role") == "user" and isinstance(msg.get("content"), list): 2034 for block in msg.get("content", []): 2035 if isinstance(block, dict) and block.get("type") == "tool_result": 2036 tool_id = block.get("tool_use_id") 2037 if tool_id: 2038 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2039 2040 # Group together all indices for each tool chain 2041 for tool_id, indices in tool_id_to_chain.items(): 2042 chain_key = f"tool_{min(indices)}" # Group by the earliest message 2043 if chain_key not in tool_chain_groups: 2044 tool_chain_groups[chain_key] = set() 2045 tool_chain_groups[chain_key].update(indices) 2046 2047 # Second pass: calculate tokens for each message 2048 message_tokens = [] 2049 for i, msg in enumerate(self.history): 2050 # Count tokens for this individual message 2051 tokens = self._count_tokens([msg]) 2052 message_tokens.append((i, tokens)) 2053 2054 # Keep the messages identified as must-keep 2055 keep_indices = set(must_keep) 2056 2057 # Calculate the tokens we've 
committed to keeping 2058 keep_tokens = sum(tokens for i, tokens in message_tokens if i in keep_indices) 2059 2060 # Check if we've already exceeded the target with just must-keep messages 2061 if keep_tokens > self.context_length: 2062 # We're in trouble - the essential messages alone exceed context 2063 # Drop older messages until we're under the limit 2064 all_indices = sorted(keep_indices) 2065 2066 # Start dropping oldest messages, but NEVER drop system messages 2067 for idx in all_indices: 2068 if idx not in system_indices: 2069 keep_indices.remove(idx) 2070 keep_tokens -= message_tokens[idx][1] 2071 if keep_tokens <= target_tokens: 2072 break 2073 2074 # If we've removed everything but system messages and still over limit 2075 if keep_tokens > self.context_length: 2076 self._log(f"[TRIM] Critical failure: even with minimal context ({keep_tokens} tokens), we exceed the limit", level="error") 2077 # Keep only system messages if any, and rebuild history before bailing out 2078 keep_indices = set(system_indices) self.history = [self.history[i] for i in sorted(keep_indices)] 2079 return True # Context exceeded completely 2080 2081 # Third pass: keep the most important tool chains intact 2082 available_tokens = target_tokens - keep_tokens 2083 # Sort tool chains by recency using each chain's earliest message index (the string chain keys would otherwise sort lexicographically, e.g. "tool_10" before "tool_2") 2084 sorted_chains = sorted(tool_chain_groups.items(), key=lambda x: min(x[1]), reverse=True) 2085 2086 for chain_key, indices in sorted_chains: 2087 # Skip if we've already decided to keep all messages in this chain 2088 if indices.issubset(keep_indices): 2089 continue 2090 2091 # Calculate how many tokens this chain would add 2092 chain_tokens = sum(tokens for i, tokens in message_tokens if i in indices and i not in keep_indices) 2093 2094 # If we can fit the entire chain, keep it 2095 if chain_tokens <= available_tokens: 2096 keep_indices.update(indices) 2097 available_tokens -= chain_tokens 2098 # Otherwise, we might want to keep partial chains in the future, but for now, skip 2099 2100 # Fourth pass: fill in with as many remaining messages as possible, prioritizing recency 2101 # Get remaining messages sorted by recency (newest first) 2102 remaining_indices = [(i, tokens) for i, tokens in message_tokens if i not in keep_indices] 2103 remaining_indices.sort(reverse=True) # Sort newest first 2104 2105 for i, tokens in remaining_indices: 2106 if tokens <= available_tokens: 2107 keep_indices.add(i) 2108 available_tokens -= tokens 2109 2110 # Final message reconstruction 2111 self._log(f"[TRIM] Keeping {len(keep_indices)}/{len(self.history)} messages, estimated {target_tokens - available_tokens} tokens", level="info") 2112 2113 # Create new history with just the kept messages, preserving order 2114 new_history = [self.history[i] for i in sorted(keep_indices)] 2115 self.history = new_history 2116 2117 # Update session_history to match the pruned history 2118 if hasattr(self, "session_history"): 2119 # Map between history items and session_history 2120 session_to_keep = [] 2121 2122 # For each session history message, check if it corresponds to a kept message 2123 for session_msg in self.session_history: 2124 # Keep system messages 2125 if session_msg.get("role") == "system": 2126 session_to_keep.append(session_msg) 2127 continue 2128 2129 # Try to match based on available IDs or content 2130 msg_id = session_msg.get("id") 2131 2132 # For tool messages, check tool_info.id against tool_call_id 2133 if "metadata" in session_msg and "tool_info" in session_msg["metadata"]: 2134 tool_id = session_msg["metadata"]["tool_info"].get("id") 2135 2136 # Check if this tool_id is
still in the kept history 2137 for history_msg in new_history: 2138 # Check standard ids 2139 history_tool_id = None 2140 2141 # Check OpenAI format 2142 if history_msg.get("role") == "tool": 2143 history_tool_id = history_msg.get("tool_call_id") 2144 elif history_msg.get("role") == "assistant" and history_msg.get("tool_calls"): 2145 for call in history_msg.get("tool_calls", []): 2146 if call.get("id") == tool_id: 2147 history_tool_id = call.get("id") 2148 break 2149 2150 # Check Anthropic format 2151 elif isinstance(history_msg.get("content"), list): 2152 for block in history_msg.get("content", []): 2153 if isinstance(block, dict): 2154 if block.get("type") == "tool_use" and block.get("id") == tool_id: 2155 history_tool_id = block.get("id") 2156 break 2157 elif block.get("type") == "tool_result" and block.get("tool_use_id") == tool_id: 2158 history_tool_id = block.get("tool_use_id") 2159 break 2160 2161 if history_tool_id == tool_id: 2162 session_to_keep.append(session_msg) 2163 break 2164 2165 # For regular messages, try content matching as fallback 2166 else: 2167 content_match = False 2168 if isinstance(session_msg.get("content"), str) and session_msg.get("content"): 2169 for history_msg in new_history: 2170 if history_msg.get("role") == session_msg.get("role") and history_msg.get("content") == session_msg.get("content"): 2171 content_match = True 2172 break 2173 2174 if content_match: 2175 session_to_keep.append(session_msg) 2176 2177 # Update session_history with kept messages 2178 self.session_history = session_to_keep 2179 2180 # Re-normalize to ensure consistency 2181 self._normalizer(force=True) 2182 2183 # Verify our final token count 2184 final_token_count = self._count_tokens(self.history) 2185 self._log(f"[TRIM] Final history has {len(self.history)} messages, {final_token_count} tokens", level="info") 2186 2187 # Return whether we've completely exceeded context 2188 return final_token_count > self.context_length or len(self.history) == len(system_indices) 2189 2190 2191 def messages_add( 2192 self, 2193 role: str, 2194 content: Any, 2195 tool_info: Optional[Dict] = None, 2196 normalize: bool = True 2197 ) -> str: 2198 """ 2199 Add a message to the standardized session_history and then update SDK-specific history. 2200 2201 This method is the central point for all message additions to the conversation. 
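        Typical calls look like this (illustrative only; the tool ID and tool
        name are hypothetical):

            msg_id = interactor.messages_add(role="user", content="Summarize the report.")
            interactor.messages_add(
                role="tool",
                content={"status": "ok"},
                tool_info={"id": "call_123", "name": "fetch_report", "arguments": {}},
            )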
2202 2203 Args: 2204 role: The role of the message ("user", "assistant", "system", "tool") 2205 content: The message content (text or structured) 2206 tool_info: Optional tool-related metadata 2207 normalize: Whether to normalize history after adding this message 2208 2209 Returns: 2210 str: Unique ID of the added message 2211 """ 2212 # Generate a unique message ID 2213 message_id = str(uuid.uuid4()) 2214 2215 # Create the standardized message for session_history 2216 timestamp = datetime.now(timezone.utc).isoformat() 2217 2218 # Store system messages directly 2219 if role == "system": 2220 self.system = content 2221 2222 # Create standard format message 2223 standard_message = { 2224 "role": role, 2225 "content": content, 2226 "id": message_id, 2227 "timestamp": timestamp, 2228 "metadata": { 2229 "sdk": self.sdk 2230 } 2231 } 2232 2233 # Add tool info if provided 2234 if tool_info: 2235 standard_message["metadata"]["tool_info"] = tool_info 2236 2237 # Add to session_history 2238 if not hasattr(self, "session_history"): 2239 self.session_history = [] 2240 2241 self.session_history.append(standard_message) 2242 2243 # Save to persistent session if enabled 2244 if self.session_enabled and self.session_id: 2245 # Convert standard message to session-compatible format 2246 session_msg = { 2247 "role": role, 2248 "content": content, 2249 "id": message_id, 2250 "timestamp": timestamp 2251 } 2252 2253 # Add tool-related fields if present 2254 if tool_info: 2255 for key, value in tool_info.items(): 2256 session_msg[key] = value 2257 2258 # Store in session 2259 self.session.msg_insert(self.session_id, session_msg) 2260 2261 # Update the SDK-specific format in self.history by running the normalizer 2262 if normalize: 2263 # We only need to normalize the most recent message for efficiency 2264 # Pass a flag indicating we're just normalizing a new message 2265 self._normalizer(force=False, new_message_only=True) 2266 2267 # Log the added message 2268 self._log(f"[MESSAGE ADDED] {role}: {str(content)[:50]}...") 2269 2270 return message_id 2271 2272 2273 def messages_system(self, prompt: str): 2274 """Set or retrieve the current system prompt. 2275 2276 This method manages the system prompt that guides the AI's behavior. 2277 It can be used to both set a new system prompt and retrieve the current one. 2278 When setting a new prompt, it updates the system message in the conversation 2279 history and persists it to the session if enabled. 2280 2281 Args: 2282 prompt (str): The new system prompt to set. If empty or None, 2283 returns the current system prompt without changes. 
2284 2285 Returns: 2286 str: The current system prompt after any updates 2287 """ 2288 if not isinstance(prompt, str) or not prompt: 2289 return self.system 2290 2291 # If the prompt hasn't changed, don't do anything 2292 if self.system == prompt: 2293 return self.system 2294 2295 # Update the system prompt 2296 old_system = self.system 2297 self.system = prompt 2298 2299 # For OpenAI, update or insert the system message in history 2300 if self.sdk == "openai": 2301 # Check if there's already a system message 2302 system_index = next((i for i, msg in enumerate(self.history) 2303 if msg.get("role") == "system"), None) 2304 2305 if system_index is not None: 2306 # Update existing system message 2307 self.history[system_index]["content"] = prompt 2308 else: 2309 # Insert new system message at the beginning 2310 self.history.insert(0, {"role": "system", "content": prompt}) 2311 2312 # For Anthropic, system message is not part of history, just save it for API calls 2313 2314 # Log to session only if prompt actually changed 2315 if self.session_enabled and self.session_id and old_system != prompt: 2316 self.session.msg_insert(self.session_id, {"role": "system", "content": prompt}) 2317 2318 return self.system 2319 2320 2321 def messages(self) -> list: 2322 """Return full session messages (persisted or in-memory).""" 2323 if self.session_enabled and self.session_id: 2324 return self.session.load_full(self.session_id).get("messages", []) 2325 return self.session_history 2326 2327 2328 def messages_length(self) -> int: 2329 """Calculate the total token count for the message history.""" 2330 if not self.encoding: 2331 return 0 2332 2333 total_tokens = 0 2334 for message in self.history: 2335 if isinstance(message.get("content"), str) and message["content"]: 2336 total_tokens += len(self.encoding.encode(message["content"])) 2337 if message.get("tool_calls"): 2338 for tool_call in message["tool_calls"]: 2339 if tool_call.get("function"): 2340 total_tokens += len(self.encoding.encode(tool_call["function"].get("name", ""))) 2341 total_tokens += len(self.encoding.encode(tool_call["function"].get("arguments", ""))) 2342 return total_tokens 2343 2344 2345 def session_load(self, session_id: Optional[str]): 2346 """Load and normalize messages for a specific session. 2347 2348 This method loads a conversation session from persistent storage and 2349 normalizes the messages to the current SDK format. It handles system 2350 messages, tool calls, and maintains message ordering. If loading fails, 2351 it resets to an empty session with the default system prompt. 2352 2353 Args: 2354 session_id (Optional[str]): The ID of the session to load. 2355 If None, resets to in-memory mode. 2356 2357 Note: 2358 This method will update both session_history and history to match 2359 the loaded session's state. It also ensures proper SDK-specific 2360 message formatting.
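        Example (hypothetical usage; assumes the instance was constructed with
        session_enabled=True and a writable session_path, and that the named
        session exists):

            ai = Interactor(session_enabled=True, session_path="~/.interactor/sessions")
            ai.session_load("support-thread-42")  # hypothetical session ID
            print(len(ai.messages()))             # messages restored from storage
            ai.session_load(None)                 # falls back to an in-memory reset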
2361 """ 2362 self.session_id = session_id 2363 self._last_session_id = session_id 2364 2365 if self.session_enabled and session_id: 2366 try: 2367 # Load raw session data 2368 session_data = self.session.load_full(session_id) 2369 messages = session_data.get("messages", []) 2370 2371 # Convert session format to our standard format 2372 self.session_history = [] 2373 2374 # Track the most recent system message 2375 latest_system_msg = None 2376 2377 for msg in messages: 2378 # Extract fields 2379 role = msg.get("role", "user") 2380 content = msg.get("content", "") 2381 msg_id = msg.get("id", str(uuid.uuid4())) 2382 timestamp = msg.get("timestamp", datetime.now(timezone.utc).isoformat()) 2383 2384 # If this is a system message, track it but don't add to session_history yet 2385 if role == "system": 2386 if latest_system_msg is None or timestamp > latest_system_msg["timestamp"]: 2387 latest_system_msg = { 2388 "role": role, 2389 "content": content, 2390 "id": msg_id, 2391 "timestamp": timestamp, 2392 "metadata": {"sdk": self.sdk} 2393 } 2394 continue 2395 2396 # Build tool_info if present 2397 tool_info = None 2398 if any(key in msg for key in ["tool_use_id", "tool_call_id", "name", "arguments"]): 2399 tool_info = { 2400 "id": msg.get("tool_use_id") or msg.get("tool_call_id"), 2401 "name": msg.get("name", "unknown_tool"), 2402 "arguments": msg.get("arguments", {}) 2403 } 2404 2405 # Create standard message 2406 standard_msg = { 2407 "role": role, 2408 "content": content, 2409 "id": msg_id, 2410 "timestamp": timestamp, 2411 "metadata": { 2412 "sdk": self.sdk 2413 } 2414 } 2415 2416 if tool_info: 2417 standard_msg["metadata"]["tool_info"] = tool_info 2418 2419 self.session_history.append(standard_msg) 2420 2421 # If we found a system message, update the system property and add to history 2422 if latest_system_msg: 2423 self.system = latest_system_msg["content"] 2424 # Insert at the beginning of session_history 2425 self.session_history.insert(0, latest_system_msg) 2426 else: 2427 # If no system message was found, add the current system message 2428 self.messages_add(role="system", content=self.system) 2429 2430 # Normalize to current SDK format 2431 self._normalizer(force=True) 2432 2433 self._log(f"[SESSION] Switched to session '{session_id}'") 2434 except Exception as e: 2435 self.logger.error(f"Failed to load session '{session_id}': {e}") 2436 self.session_reset() 2437 else: 2438 # Reset to empty state with system message 2439 self.session_reset() 2440 2441 2442 def session_reset(self): 2443 """Reset the current session state and reinitialize to default system prompt. 2444 2445 This method performs a complete reset of the conversation state: 2446 1. Clears all message history 2447 2. Disables session ID tracking 2448 3. Returns to in-memory mode 2449 4. Reinitializes with the default system prompt 2450 2451 The reset is useful for starting fresh conversations or recovering 2452 from error states. It maintains the basic system configuration while 2453 clearing all conversation context. 2454 """ 2455 self.session_id = None 2456 self._last_session_id = None 2457 2458 # Clear histories 2459 self.session_history = [] 2460 self.history = [] 2461 2462 # Reapply the system message 2463 if hasattr(self, "system") and self.system: 2464 # Add to session_history 2465 self.messages_add(role="system", content=self.system) 2466 else: 2467 # Ensure we have a default system message 2468 self.system = "You are a helpful Assistant." 
2469 self.messages_add(role="system", content=self.system) 2470 2471 self._log("[SESSION] Reset to in-memory mode") 2472 2473 2474 def _normalizer(self, force=False, new_message_only=False): 2475 """Central normalization function for message format conversion. 2476 2477 This method transforms the standardized session_history into the 2478 SDK-specific format needed in self.history. It handles different 2479 message types (system, user, assistant, tool) and their various 2480 formats across different SDKs. 2481 2482 Args: 2483 force (bool): If True, always normalize even if SDK hasn't changed. 2484 Default is False, which only normalizes on SDK change. 2485 new_message_only (bool): If True, only normalize the most recent message 2486 for efficiency when adding single messages. 2487 2488 Note: 2489 This method is the central point for message format conversion and 2490 ensures consistency between session storage and API communication. 2491 """ 2492 # Skip normalization if SDK hasn't changed and force is False 2493 if not force and hasattr(self, '_last_sdk') and self._last_sdk == self.sdk: 2494 # If we only need to normalize the most recent message 2495 if new_message_only and self.session_history: 2496 # Get the most recent message from session_history 2497 recent_msg = self.session_history[-1] 2498 2499 # Apply SDK-specific normalization for just this message 2500 if self.sdk == "openai": 2501 self._openai_normalize_message(recent_msg) 2502 elif self.sdk == "anthropic": 2503 self._anthropic_normalize_message(recent_msg) 2504 else: 2505 # Generic handler for unknown SDKs 2506 self._generic_normalize_message(recent_msg) 2507 2508 return 2509 2510 # Record the current SDK to detect future changes 2511 self._last_sdk = self.sdk 2512 2513 # For full normalization, clear current history and rebuild it 2514 self.history = [] 2515 2516 # Call the appropriate SDK-specific normalizer 2517 if self.sdk == "openai": 2518 self._openai_normalizer() 2519 elif self.sdk == "anthropic": 2520 self._anthropic_normalizer() 2521 else: 2522 self.logger.warning(f"No normalizer available for SDK: {self.sdk}") 2523 # Fallback to a simple conversion for unknown SDKs 2524 for msg in self.session_history: 2525 self._generic_normalize_message(msg) 2526 2527 2528 def _openai_normalizer(self): 2529 """Convert standardized session_history to OpenAI-compatible format. 2530 2531 This method transforms the internal message format into the structure 2532 required by the OpenAI API. It handles: 2533 - System messages at the start of history 2534 - User messages with plain text 2535 - Assistant messages with optional tool calls 2536 - Tool response messages with tool_call_id 2537 2538 The resulting format matches OpenAI's chat completion API requirements 2539 for both regular messages and function calling. 
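        Example (sketch of the history this produces for one tool round-trip;
        the IDs and the tool name are hypothetical):

            [
                {"role": "system", "content": "You are a helpful Assistant."},
                {"role": "user", "content": "What is 2+2?"},
                {"role": "assistant", "content": "", "tool_calls": [
                    {"id": "call_1", "type": "function",
                     "function": {"name": "calc", "arguments": "{\"expr\": \"2+2\"}"}}]},
                {"role": "tool", "tool_call_id": "call_1", "content": "4"},
            ]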
2540 """ 2541 # For OpenAI, we need to include system message in the history 2542 # and convert tool calls/results to OpenAI format 2543 2544 # Start with empty history 2545 self.history = [] 2546 2547 # First, add the current system message at position 0 2548 self.history.append({ 2549 "role": "system", 2550 "content": self.system 2551 }) 2552 2553 # Process all non-system messages 2554 for msg in self.session_history: 2555 if msg["role"] == "system": 2556 continue # Skip system messages, already handled 2557 2558 # Handle different message types 2559 if msg["role"] == "user": 2560 # User messages are straightforward 2561 self.history.append({ 2562 "role": "user", 2563 "content": msg["content"] 2564 }) 2565 2566 elif msg["role"] == "assistant": 2567 # For assistant messages with tool calls 2568 if "metadata" in msg and msg["metadata"].get("tool_info"): 2569 # This is an assistant message with tool calls 2570 tool_info = msg["metadata"]["tool_info"] 2571 2572 # Create OpenAI assistant message with tool calls 2573 assistant_msg = { 2574 "role": "assistant", 2575 "content": msg["content"] if isinstance(msg["content"], str) else "", 2576 "tool_calls": [{ 2577 "id": tool_info["id"], 2578 "type": "function", 2579 "function": { 2580 "name": tool_info["name"], 2581 "arguments": json.dumps(tool_info["arguments"]) if isinstance(tool_info["arguments"], dict) else tool_info["arguments"] 2582 } 2583 }] 2584 } 2585 self.history.append(assistant_msg) 2586 else: 2587 # Regular assistant message 2588 self.history.append({ 2589 "role": "assistant", 2590 "content": msg["content"] 2591 }) 2592 2593 elif msg["role"] == "tool": 2594 # Tool response messages 2595 if "metadata" in msg and "tool_info" in msg["metadata"]: 2596 tool_msg = { 2597 "role": "tool", 2598 "tool_call_id": msg["metadata"]["tool_info"]["id"], 2599 "content": json.dumps(msg["content"]) if isinstance(msg["content"], (dict, list)) else msg["content"] 2600 } 2601 self.history.append(tool_msg) 2602 2603 2604 def _anthropic_normalizer(self): 2605 """Convert standardized session_history to Anthropic-compatible format. 2606 2607 This method transforms the internal message format into the structure 2608 required by the Anthropic API. It handles: 2609 - System messages (stored separately, not in history) 2610 - User messages with optional tool results 2611 - Assistant messages with optional tool use 2612 - Content blocks for structured responses 2613 2614 The resulting format matches Anthropic's message API requirements 2615 for both regular messages and tool use. 
2616 """ 2617 # For Anthropic, we don't include system message in the history 2618 # but need to handle content blocks for tool use/results 2619 2620 # Start with empty history 2621 self.history = [] 2622 2623 # Process all non-system messages 2624 for msg in self.session_history: 2625 if msg["role"] == "system": 2626 # Update system prompt if this is the most recent system message 2627 # (only apply the most recent system message if we have multiple) 2628 if msg == self.session_history[-1] or all(m["role"] != "system" for m in self.session_history[self.session_history.index(msg)+1:]): 2629 self.system = msg["content"] 2630 continue # Skip system messages in history 2631 2632 # Handle different message types 2633 if msg["role"] == "user": 2634 # User messages - check if it contains tool results 2635 if "metadata" in msg and "tool_info" in msg["metadata"] and msg["metadata"]["tool_info"].get("result"): 2636 # This is a tool result message 2637 tool_info = msg["metadata"]["tool_info"] 2638 2639 # Create Anthropic tool result format 2640 tool_result_msg = { 2641 "role": "user", 2642 "content": [{ 2643 "type": "tool_result", 2644 "tool_use_id": tool_info["id"], 2645 "content": json.dumps(tool_info["result"]) if isinstance(tool_info["result"], (dict, list)) else str(tool_info["result"]) 2646 }] 2647 } 2648 self.history.append(tool_result_msg) 2649 else: 2650 # Regular user message 2651 self.history.append({ 2652 "role": "user", 2653 "content": msg["content"] 2654 }) 2655 2656 elif msg["role"] == "assistant": 2657 # For assistant messages, check for tool use 2658 if "metadata" in msg and "tool_info" in msg["metadata"]: 2659 # This is an assistant message with tool use 2660 tool_info = msg["metadata"]["tool_info"] 2661 2662 # Build content blocks 2663 content_blocks = [] 2664 2665 # Add text content if present 2666 if msg["content"]: 2667 content_blocks.append({ 2668 "type": "text", 2669 "text": msg["content"] if isinstance(msg["content"], str) else "" 2670 }) 2671 2672 # Add tool use block 2673 content_blocks.append({ 2674 "type": "tool_use", 2675 "id": tool_info["id"], 2676 "name": tool_info["name"], 2677 "input": tool_info["arguments"] if isinstance(tool_info["arguments"], dict) else json.loads(tool_info["arguments"]) 2678 }) 2679 2680 # Create Anthropic assistant message with tool use 2681 self.history.append({ 2682 "role": "assistant", 2683 "content": content_blocks 2684 }) 2685 else: 2686 # Regular assistant message 2687 self.history.append({ 2688 "role": "assistant", 2689 "content": msg["content"] 2690 }) 2691 2692 elif msg["role"] == "tool": 2693 # Tool messages in standard format get converted to user messages with tool_result 2694 if "metadata" in msg and "tool_info" in msg["metadata"]: 2695 tool_info = msg["metadata"]["tool_info"] 2696 2697 # Create Anthropic tool result message 2698 tool_result_msg = { 2699 "role": "user", 2700 "content": [{ 2701 "type": "tool_result", 2702 "tool_use_id": tool_info["id"], 2703 "content": json.dumps(msg["content"]) if isinstance(msg["content"], (dict, list)) else str(msg["content"]) 2704 }] 2705 } 2706 self.history.append(tool_result_msg) 2707 2708 2709 def _openai_normalize_message(self, msg): 2710 """Normalize a single message to OpenAI format and add to history.""" 2711 role = msg.get("role") 2712 content = msg.get("content") 2713 2714 if role == "system": 2715 # Check if we already have a system message in history 2716 system_index = next((i for i, m in enumerate(self.history) 2717 if m.get("role") == "system"), None) 2718 if system_index is not 
None: 2719 # Update existing system message 2720 self.history[system_index]["content"] = content 2721 else: 2722 # Insert new system message at the beginning 2723 self.history.insert(0, { 2724 "role": "system", 2725 "content": content 2726 }) 2727 # Update the system property 2728 self.system = content 2729 2730 elif role == "user": 2731 self.history.append({ 2732 "role": "user", 2733 "content": content 2734 }) 2735 2736 elif role == "assistant": 2737 # For assistant messages, handle potential tool calls 2738 if "metadata" in msg and msg["metadata"].get("tool_info"): 2739 # This is an assistant message with tool calls 2740 tool_info = msg["metadata"]["tool_info"] 2741 2742 # Create OpenAI assistant message with tool calls 2743 try: 2744 arguments = tool_info.get("arguments", {}) 2745 arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else arguments 2746 except (TypeError, ValueError): 2747 arguments_str = str(tool_info.get("arguments", "")) 2748 2749 assistant_msg = { 2750 "role": "assistant", 2751 "content": content if isinstance(content, str) else "", 2752 "tool_calls": [{ 2753 "id": tool_info["id"], 2754 "type": "function", 2755 "function": { 2756 "name": tool_info["name"], 2757 "arguments": arguments_str 2758 } 2759 }] 2760 } 2761 self.history.append(assistant_msg) 2762 else: 2763 # Regular assistant message 2764 self.history.append({ 2765 "role": "assistant", 2766 "content": content 2767 }) 2768 2769 elif role == "tool": 2770 # Tool response messages 2771 if "metadata" in msg and "tool_info" in msg["metadata"]: 2772 tool_info = msg["metadata"]["tool_info"] 2773 tool_msg = { 2774 "role": "tool", 2775 "tool_call_id": tool_info["id"], 2776 "content": json.dumps(content) if isinstance(content, (dict, list)) else str(content) 2777 } 2778 self.history.append(tool_msg) 2779 2780 2781 def _anthropic_normalize_message(self, msg): 2782 """Normalize a single message to Anthropic format and add to history.""" 2783 role = msg.get("role") 2784 content = msg.get("content") 2785 2786 if role == "system": 2787 # Store system prompt separately, not in history for Anthropic 2788 self.system = content 2789 2790 elif role == "user": 2791 # User messages - check if it contains tool results 2792 if "metadata" in msg and "tool_info" in msg["metadata"]: 2793 tool_info = msg["metadata"]["tool_info"] 2794 # Check for result or directly use content 2795 result_content = tool_info.get("result", content) 2796 2797 # Create Anthropic tool result format 2798 try: 2799 result_str = json.dumps(result_content) if isinstance(result_content, (dict, list)) else str(result_content) 2800 except (TypeError, ValueError): 2801 result_str = str(result_content) 2802 2803 tool_result_msg = { 2804 "role": "user", 2805 "content": [{ 2806 "type": "tool_result", 2807 "tool_use_id": tool_info["id"], 2808 "content": result_str 2809 }] 2810 } 2811 self.history.append(tool_result_msg) 2812 else: 2813 # Regular user message 2814 self.history.append({ 2815 "role": "user", 2816 "content": content 2817 }) 2818 2819 elif role == "assistant": 2820 # For assistant messages, check for tool use 2821 if "metadata" in msg and "tool_info" in msg["metadata"]: 2822 # This is an assistant message with tool use 2823 tool_info = msg["metadata"]["tool_info"] 2824 2825 # Build content blocks 2826 content_blocks = [] 2827 2828 # Add text content if present 2829 if content: 2830 content_blocks.append({ 2831 "type": "text", 2832 "text": content if isinstance(content, str) else "" 2833 }) 2834 2835 # Add tool use block - safely convert arguments 2836 try: 2837 # Parse arguments to ensure it's a dictionary 2838
if isinstance(tool_info["arguments"], str): 2839 try: 2840 args_dict = json.loads(tool_info["arguments"]) 2841 except json.JSONDecodeError: 2842 args_dict = {"text": tool_info["arguments"]} 2843 else: 2844 args_dict = tool_info["arguments"] 2845 except Exception: 2846 args_dict = {"error": "Failed to parse arguments"} 2847 2848 content_blocks.append({ 2849 "type": "tool_use", 2850 "id": tool_info["id"], 2851 "name": tool_info["name"], 2852 "input": args_dict 2853 }) 2854 2855 # Create Anthropic assistant message with tool use 2856 self.history.append({ 2857 "role": "assistant", 2858 "content": content_blocks 2859 }) 2860 else: 2861 # Regular assistant message 2862 self.history.append({ 2863 "role": "assistant", 2864 "content": content 2865 }) 2866 2867 elif role == "tool": 2868 # Tool messages in standard format get converted to user messages with tool_result 2869 if "metadata" in msg and "tool_info" in msg["metadata"]: 2870 tool_info = msg["metadata"]["tool_info"] 2871 2872 try: 2873 result_str = json.dumps(content) if isinstance(content, (dict, list)) else str(content) 2874 except (TypeError, ValueError): 2875 result_str = str(content) 2876 2877 # Create Anthropic tool result message 2878 tool_result_msg = { 2879 "role": "user", 2880 "content": [{ 2881 "type": "tool_result", 2882 "tool_use_id": tool_info["id"], 2883 "content": result_str 2884 }] 2885 } 2886 self.history.append(tool_result_msg) 2887 2888 2889 def _generic_normalize_message(self, msg): 2890 """Generic normalizer for unknown SDKs. 2891 2892 This method provides a basic message normalization for SDKs that 2893 don't have specific handling. It performs minimal conversion to 2894 ensure basic message structure is maintained. 2895 2896 Args: 2897 msg (dict): The message to normalize, containing at minimum: 2898 - role: Message role (user/assistant/system) 2899 - content: Message content 2900 2901 Note: 2902 This is a fallback method and should be overridden for specific 2903 SDK implementations when possible. 2904 """ 2905 role = msg.get("role") 2906 content = msg.get("content") 2907 2908 if role in ["user", "assistant", "system"]: 2909 self.history.append({ 2910 "role": role, 2911 "content": content 2912 }) 2913 2914 2915 def track_token_usage(self): 2916 """Track and return token usage across the conversation history. 2917 2918 This method maintains a history of token usage measurements and provides 2919 current usage statistics. It tracks: 2920 - Current token count 2921 - Context length limit 2922 - Usage percentage 2923 - Historical measurements 2924 - Current provider and model info 2925 2926 Returns: 2927 dict: Dictionary containing: 2928 - current: Current token count 2929 - limit: Maximum context length 2930 - percentage: Usage as percentage of limit 2931 - history: Last 10 measurements with timestamps 2932 - provider: Current provider name 2933 - model: Current model name 2934 2935 Note: 2936 The history is limited to the last 100 measurements to prevent 2937 unbounded memory growth.
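        Example (hypothetical usage on an existing instance):

            usage = interactor.track_token_usage()
            print(f"{usage['current']}/{usage['limit']} tokens ({usage['percentage']}%)")
            for sample in usage["history"]:
                print(sample["timestamp"], sample["count"])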
2938 """ 2939 if not hasattr(self, "_token_history"): 2940 self._token_history = [] 2941 2942 # Count current tokens 2943 current_count = self._count_tokens(self.history) 2944 2945 # Add to history 2946 timestamp = datetime.now(timezone.utc).isoformat() 2947 self._token_history.append({ 2948 "timestamp": timestamp, 2949 "count": current_count, 2950 "limit": self.context_length, 2951 "provider": self.provider, 2952 "model": self.model 2953 }) 2954 2955 # Keep only the last 100 measurements to avoid unlimited growth 2956 if len(self._token_history) > 100: 2957 self._token_history = self._token_history[-100:] 2958 2959 # Return current tracking info 2960 return { 2961 "current": current_count, 2962 "limit": self.context_length, 2963 "percentage": round((current_count / self.context_length) * 100, 1) if self.context_length else 0, 2964 "history": self._token_history[-10:], # Return last 10 measurements 2965 "provider": self.provider, 2966 "model": self.model 2967 } 2968 2969 2970 def get_message_token_breakdown(self): 2971 """Analyze token usage by message type and provide a detailed breakdown. 2972 2973 This method performs a detailed analysis of token usage across the 2974 conversation history, breaking down usage by: 2975 - Message role (system, user, assistant, tool) 2976 - Content type (text, tool calls, tool results) 2977 - Individual message statistics 2978 2979 Returns: 2980 dict: Token usage breakdown containing: 2981 - total: Total tokens used 2982 - by_role: Tokens used by each role 2983 - by_type: Tokens used by content type 2984 - messages: List of individual message stats including: 2985 - index: Message position 2986 - role: Message role 2987 - tokens: Tokens used 2988 - has_tools: Whether message contains tool calls 2989 2990 Note: 2991 This analysis is useful for understanding token usage patterns 2992 and optimizing conversation context. 2993 """ 2994 breakdown = { 2995 "total": 0, 2996 "by_role": { 2997 "system": 0, 2998 "user": 0, 2999 "assistant": 0, 3000 "tool": 0 3001 }, 3002 "by_type": { 3003 "text": 0, 3004 "tool_calls": 0, 3005 "tool_results": 0 3006 }, 3007 "messages": [] 3008 } 3009 3010 # Analyze each message 3011 for i, msg in enumerate(self.history): 3012 msg_tokens = self._count_tokens([msg]) 3013 role = msg.get("role", "unknown") 3014 3015 # Track by role 3016 if role in breakdown["by_role"]: 3017 breakdown["by_role"][role] += msg_tokens 3018 3019 # Track by content type 3020 if role == "assistant" and msg.get("tool_calls"): 3021 breakdown["by_type"]["tool_calls"] += msg_tokens 3022 elif role == "tool": 3023 breakdown["by_type"]["tool_results"] += msg_tokens 3024 else: 3025 breakdown["by_type"]["text"] += msg_tokens 3026 3027 # Add individual message data 3028 breakdown["messages"].append({ 3029 "index": i, 3030 "role": role, 3031 "tokens": msg_tokens, 3032 "has_tools": bool(msg.get("tool_calls") or msg.get("tool_use") or 3033 (isinstance(msg.get("content"), list) and 3034 any(isinstance(c, dict) and c.get("type") in ["tool_use", "tool_result"] 3035 for c in msg.get("content", [])))) 3036 }) 3037 3038 # Update tota? 3039 breakdown["total"] += msg_tokens 3040 3041 return breakdown
373 def add_function( 374 self, 375 external_callable: Callable, 376 name: Optional[str] = None, 377 description: Optional[str] = None, 378 override: bool = False, 379 disabled: bool = False, 380 schema_extensions: Optional[Dict[str, Any]] = None 381 ): 382 """ 383 Register a function for LLM tool calling with full type hints and metadata. 384 385 Args: 386 external_callable (Callable): The function to register. 387 name (Optional[str]): Optional custom name. Defaults to function's __name__. 388 description (Optional[str]): Optional custom description. Defaults to first line of docstring. 389 override (bool): If True, replaces an existing tool with the same name. 390 disabled (bool): If True, registers the function in a disabled state. 391 schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to 392 schema extensions that override or add to the auto-generated schema. 393 394 Raises: 395 ValueError: If the callable is invalid or duplicate name found without override. 396 397 Example: 398 interactor.add_function( 399 my_tool, 400 override=True, 401 disabled=False, 402 schema_extensions={ 403 "param1": {"minimum": 0, "maximum": 100}, 404 "param2": {"format": "email"} 405 } 406 ) 407 """ 408 def _python_type_to_schema(ptype: Any) -> dict: 409 """Convert a Python type annotation to OpenAI-compatible JSON Schema.""" 410 # Handle None case 411 if ptype is None: 412 return {"type": "null"} 413 414 # Get the origin and arguments of the type 415 origin = get_origin(ptype) 416 args = get_args(ptype) 417 418 # Handle Union types (including Optional) 419 if origin is Union: 420 # Check for Optional (Union with None) 421 none_type = type(None) 422 if none_type in args: 423 non_none = [a for a in args if a is not none_type] 424 if len(non_none) == 1: 425 inner = _python_type_to_schema(non_none[0]) 426 inner_copy = inner.copy() 427 inner_copy["nullable"] = True 428 return inner_copy 429 # Multiple types excluding None 430 types = [_python_type_to_schema(a) for a in non_none] 431 return {"anyOf": types, "nullable": True} 432 # Regular Union without None 433 return {"anyOf": [_python_type_to_schema(a) for a in args]} 434 435 # Handle List and similar container types 436 if origin in (list, List): 437 item_type = args[0] if args else Any 438 if item_type is Any: 439 return {"type": "array"} 440 return {"type": "array", "items": _python_type_to_schema(item_type)} 441 442 # Handle Dict types with typing info 443 if origin in (dict, Dict): 444 if not args or len(args) != 2: 445 return {"type": "object"} 446 447 key_type, val_type = args 448 # We can only really use val_type in JSON Schema 449 if val_type is not Any and val_type is not object: 450 return { 451 "type": "object", 452 "additionalProperties": _python_type_to_schema(val_type) 453 } 454 return {"type": "object"} 455 456 # Handle Literal types for enums 457 if origin is Literal: 458 values = args 459 # Try to determine type from values 460 if all(isinstance(v, str) for v in values): 461 return {"type": "string", "enum": list(values)} 462 elif all(isinstance(v, bool) for v in values): 463 return {"type": "boolean", "enum": list(values)} 464 elif all(isinstance(v, (int, float)) for v in values): 465 return {"type": "number", "enum": list(values)} 466 else: 467 # Mixed types, use anyOf 468 return {"anyOf": [{"type": _get_json_type(v), "enum": [v]} for v in values]} 469 470 # Handle basic types 471 if ptype is str: 472 return {"type": "string"} 473 if ptype is int: 474 return {"type": "integer"} 475 if ptype is float: 
476 return {"type": "number"} 477 if ptype is bool: 478 return {"type": "boolean"} 479 480 # Handle common datetime types 481 if ptype is datetime: 482 return {"type": "string", "format": "date-time"} 483 if ptype is date: 484 return {"type": "string", "format": "date"} 485 486 # Handle UUID 487 if ptype is uuid.UUID: 488 return {"type": "string", "format": "uuid"} 489 490 # Default to object for any other types 491 return {"type": "object"} 492 493 def _get_json_type(value): 494 """Get the JSON Schema type name for a Python value. 495 496 This helper function maps Python types to their corresponding 497 JSON Schema type names. It handles basic types and provides 498 sensible defaults for complex types. 499 500 Args: 501 value: The Python value to get the JSON type for 502 503 Returns: 504 str: The JSON Schema type name ('string', 'number', 'boolean', 505 'array', 'object', or 'object' as default) 506 """ 507 if isinstance(value, str): 508 return "string" 509 elif isinstance(value, bool): 510 return "boolean" 511 elif isinstance(value, int) or isinstance(value, float): 512 return "number" 513 elif isinstance(value, list): 514 return "array" 515 elif isinstance(value, dict): 516 return "object" 517 else: 518 return "object" # Default 519 520 def _parse_param_docs(docstring: str) -> dict: 521 """Extract parameter descriptions from a docstring.""" 522 if not docstring: 523 return {} 524 525 lines = docstring.splitlines() 526 param_docs = {} 527 current_param = None 528 in_params = False 529 530 # Regular expressions for finding parameter sections and param lines 531 param_section_re = re.compile(r"^(Args|Parameters):\s*$") 532 param_line_re = re.compile(r"^\s{4}(\w+)\s*(?:\([^\)]*\))?:\s*(.*)") 533 534 for line in lines: 535 # Check if we're entering the parameters section 536 if param_section_re.match(line.strip()): 537 in_params = True 538 continue 539 540 if in_params: 541 # Skip empty lines 542 if not line.strip(): 543 continue 544 545 # Check for a parameter definition line 546 match = param_line_re.match(line) 547 if match: 548 current_param = match.group(1) 549 param_docs[current_param] = match.group(2).strip() 550 # Check for continuation of a parameter description 551 elif current_param and line.startswith(" " * 8): 552 param_docs[current_param] += " " + line.strip() 553 # If we see a line that doesn't match our patterns, we're out of the params section 554 else: 555 current_param = None 556 557 return param_docs 558 559 # Start of main function logic 560 561 # Skip if tools are disabled 562 if not self.tools_enabled: 563 return 564 565 # Validate input callable 566 if not external_callable: 567 raise ValueError("A valid external callable must be provided.") 568 569 # Set function name, either from parameter or from callable's __name__ 570 function_name = name or external_callable.__name__ 571 572 # Try to get docstring and extract description 573 try: 574 docstring = inspect.getdoc(external_callable) 575 description = description or (docstring.split("\n")[0].strip() if docstring else "No description provided.") 576 except Exception as e: 577 self._log(f"[TOOL] Warning: Could not extract docstring from {function_name}: {e}", level="warning") 578 docstring = "" 579 description = description or "No description provided." 
580 581 # Extract parameter documentation from docstring 582 param_docs = _parse_param_docs(docstring) 583 584 # Handle conflicts with existing functions 585 if override: 586 self.delete_function(function_name) 587 elif any(t["function"]["name"] == function_name for t in self.tools): 588 raise ValueError(f"Function '{function_name}' is already registered. Use override=True to replace.") 589 590 # Try to get function signature for parameter info 591 try: 592 signature = inspect.signature(external_callable) 593 except (ValueError, TypeError) as e: 594 raise ValueError(f"Cannot inspect callable '{function_name}': {e}") 595 596 # Process parameters to build schema 597 properties = {} 598 required = [] 599 600 for param_name, param in signature.parameters.items(): 601 # Skip self, cls parameters for instance/class methods 602 if param_name in ("self", "cls") and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: 603 continue 604 605 # Get parameter annotation, defaulting to Any 606 annotation = param.annotation if param.annotation != inspect.Parameter.empty else Any 607 608 try: 609 # Convert Python type to JSON Schema 610 schema = _python_type_to_schema(annotation) 611 612 # Add description from docstring or create a default one 613 schema["description"] = param_docs.get(param_name, f"{param_name} parameter") 614 615 # Add to properties 616 properties[param_name] = schema 617 618 # If no default value is provided, parameter is required 619 if param.default == inspect.Parameter.empty: 620 required.append(param_name) 621 self._log(f"[TOOL] Parameter '{param_name}' is required", level="debug") 622 else: 623 self._log(f"[TOOL] Parameter '{param_name}' has default value: {param.default}", level="debug") 624 625 except Exception as e: 626 self._log(f"[TOOL] Error processing parameter {param_name} for {function_name}: {e}", level="error") 627 # Add a basic object schema as fallback 628 properties[param_name] = { 629 "type": "string", # Default to string instead of object for better compatibility 630 "description": f"{param_name} parameter (type conversion failed)" 631 } 632 633 # For parameters with no default value, mark as required even if processing failed 634 if param.default == inspect.Parameter.empty: 635 required.append(param_name) 636 self._log(f"[TOOL] Parameter '{param_name}' marked as required despite conversion failure", level="debug") 637 638 # Apply schema extensions if provided 639 if schema_extensions: 640 for param_name, extensions in schema_extensions.items(): 641 if param_name in properties: 642 properties[param_name].update(extensions) 643 644 # Create parameters object with proper placement of 'required' field 645 parameters = { 646 "type": "object", 647 "properties": properties, 648 } 649 650 # Only add required field if there are required parameters 651 if required: 652 parameters["required"] = required 653 654 # Build the final tool specification 655 tool_spec = { 656 "type": "function", 657 "function": { 658 "name": function_name, 659 "description": description, 660 "parameters": parameters 661 } 662 } 663 664 # Set disabled flag if requested 665 if disabled: 666 tool_spec["function"]["disabled"] = True 667 668 # Add to tools list 669 self.tools.append(tool_spec) 670 671 # Make the function available as an attribute on the instance 672 setattr(self, function_name, external_callable) 673 674 # Log the registration with detailed information 675 self._log(f"[TOOL] Registered function '{function_name}' with {len(properties)} parameters", level="info") 676 if required: 677 
self._log(f"[TOOL] Required parameters: {required}", level="info") 678 679 return function_name # Return the name for reference
682 def disable_function(self, name: str) -> bool: 683 """ 684 Disable a registered tool function by name. 685 686 This marks the function as inactive for tool calling without removing it from the internal registry. 687 The function remains visible in the tool listing but is skipped during tool selection by the LLM. 688 689 Args: 690 name (str): The name of the function to disable. 691 692 Returns: 693 bool: True if the function was found and disabled, False otherwise. 694 695 Example: 696 interactor.disable_function("extract_text") 697 """ 698 for tool in self.tools: 699 if tool["function"]["name"] == name: 700 tool["function"]["disabled"] = True 701 return True 702 return False
705 def enable_function(self, name: str) -> bool: 706 """ 707 Re-enable a previously disabled tool function by name. 708 709 This removes the 'disabled' flag from a tool function, making it available again for LLM use. 710 711 Args: 712 name (str): The name of the function to enable. 713 714 Returns: 715 bool: True if the function was found and enabled, False otherwise. 716 717 Example: 718 interactor.enable_function("extract_text") 719 """ 720 for tool in self.tools: 721 if tool["function"]["name"] == name: 722 tool["function"].pop("disabled", None) 723 return True 724 return False
727 def delete_function(self, name: str) -> bool: 728 """ 729 Permanently remove a registered tool function from the Interactor. 730 731 This deletes both the tool metadata and the callable attribute, making it fully inaccessible 732 from the active session. Useful for dynamically trimming the toolset. 733 734 Args: 735 name (str): The name of the function to delete. 736 737 Returns: 738 bool: True if the function was found and removed, False otherwise. 739 740 Example: 741 interactor.delete_function("extract_text") 742 """ 743 before = len(self.tools) 744 self.tools = [tool for tool in self.tools if tool["function"]["name"] != name] 745 if hasattr(self, name): 746 delattr(self, name) 747 return len(self.tools) < before
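The three methods above form a simple lifecycle, sketched here with the hypothetical get_weather tool from the earlier example:

ai.disable_function("get_weather")  # stays in the registry, skipped by the LLM
ai.enable_function("get_weather")   # removes the "disabled" flag again
ai.delete_function("get_weather")   # drops the spec and the instance attribute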
750 def list_functions(self) -> List[Dict[str, Any]]: 751 """Get the list of registered functions for tool calling. 752 753 Returns: 754 List[Dict[str, Any]]: List of registered functions. 755 """ 756 return self.tools
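Because list_functions() returns the raw specs, callers can filter on the "disabled" flag themselves; a small sketch using the spec layout shown in add_function:

enabled = [t for t in ai.list_functions()
           if not t["function"].get("disabled", False)]
print([t["function"]["name"] for t in enabled])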
881 def interact( 882 self, 883 user_input: Optional[str], 884 quiet: bool = False, 885 tools: bool = True, 886 stream: bool = True, 887 markdown: bool = False, 888 model: Optional[str] = None, 889 output_callback: Optional[Callable[[str], None]] = None, 890 session_id: Optional[str] = None, 891 raw: Optional[bool] = None, 892 tool_suppress: bool = True, 893 timeout: float = 60.0 894 ) -> Union[Optional[str], "TokenStream"]: 895 """Main universal gateway for all LLM interaction. 896 897 This function serves as the single entry point for all interactions with the language model. 898 When `raw=False` (default), it handles the interaction internally and returns the full response. 899 When `raw=True`, it returns a context manager that yields chunks of the response for custom handling. 900 901 Args: 902 user_input: Text input from the user. 903 quiet: If True, don't print status info or progress. 904 tools: Enable (True) or disable (False) tool calling. 905 stream: Enable (True) or disable (False) streaming responses. 906 markdown: If True, renders content as markdown. 907 model: Optional model override. 908 output_callback: Optional callback to handle the output. 909 session_id: Optional session ID to load messages from. 910 raw: If True, return a context manager instead of handling the interaction internally. 911 If None, use the class-level setting from __init__. 912 tool_suppress: If True and raw=True, filter out tool-related status messages. 913 timeout: Maximum time in seconds to wait for the stream to complete when raw=True. 914 915 Returns: 916 If raw=False: The complete response from the model as a string, or None if there was an error. 917 If raw=True: A context manager that yields chunks of the response as they arrive. 918 919 Example with default mode: 920 response = ai.interact("Tell me a joke") 921 922 Example with raw mode: 923 with ai.interact("Tell me a joke", raw=True) as stream: 924 for chunk in stream: 925 print(chunk, end="", flush=True) 926 """ 927 if not user_input: 928 return None 929 930 if quiet or self.quiet: 931 markdown = False 932 stream = False 933 934 # Determine if we should use raw mode 935 # If raw parameter is explicitly provided, use that; otherwise use class setting 936 use_raw = self.raw if raw is None else raw 937 938 # If raw mode is requested, delegate to interact_raw 939 if use_raw: 940 return self._interact_raw( 941 user_input=user_input, 942 tools=tools, 943 model=model, 944 session_id=session_id, 945 tool_suppress=tool_suppress, 946 timeout=timeout 947 ) 948 949 # Setup model if specified 950 if model: 951 self._setup_client(model) 952 self._setup_encoding() 953 954 # Session handling 955 if self.session_enabled and session_id: 956 self.session_id = session_id 957 self.session_load(session_id) 958 959 # Add user message using messages_add 960 self.messages_add(role="user", content=user_input) 961 962 # Log token count estimate 963 token_count = self._count_tokens(self.history) 964 if not quiet: 965 print(f"[dim]Estimated tokens in context: {token_count} / {self.context_length}[/dim]") 966 967 # Make sure we have enough context space 968 if token_count > self.context_length: 969 if self._cycle_messages(): 970 if not quiet: 971 print("[red]Context window exceeded. Cannot proceed.[/red]") 972 return None 973 974 # Log user input 975 self._log(f"[USER] {user_input}") 976 977 # Handle the actual interaction with complete streaming for all responses 978 result = asyncio.run(self._interact_async_core( 979 user_input=user_input, 980 quiet=quiet, 981 tools=tools, 982 stream=stream, 983 markdown=markdown, 984 output_callback=output_callback 985 )) 986 987 # Log completion for this interaction 988 self._log(f"[INTERACTION] Completed with {len(self.history)} total messages") 989 990 return result
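In addition to the docstring examples, output_callback offers a way to capture output without printing; a sketch, with the caveat that whether the callback receives incremental chunks or one final string depends on the streaming path:

chunks = []
ai.interact(
    "Summarize the benefits of streaming APIs.",
    quiet=True,
    output_callback=chunks.append,
)
text = "".join(chunks)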
2191 def messages_add( 2192 self, 2193 role: str, 2194 content: Any, 2195 tool_info: Optional[Dict] = None, 2196 normalize: bool = True 2197 ) -> str: 2198 """ 2199 Add a message to the standardized session_history and then update SDK-specific history. 2200 2201 This method is the central point for all message additions to the conversation. 2202 2203 Args: 2204 role: The role of the message ("user", "assistant", "system", "tool") 2205 content: The message content (text or structured) 2206 tool_info: Optional tool-related metadata 2207 normalize: Whether to normalize history after adding this message 2208 2209 Returns: 2210 str: Unique ID of the added message 2211 """ 2212 # Generate a unique message ID 2213 message_id = str(uuid.uuid4()) 2214 2215 # Create the standardized message for session_history 2216 timestamp = datetime.now(timezone.utc).isoformat() 2217 2218 # Store system messages directly 2219 if role == "system": 2220 self.system = content 2221 2222 # Create standard format message 2223 standard_message = { 2224 "role": role, 2225 "content": content, 2226 "id": message_id, 2227 "timestamp": timestamp, 2228 "metadata": { 2229 "sdk": self.sdk 2230 } 2231 } 2232 2233 # Add tool info if provided 2234 if tool_info: 2235 standard_message["metadata"]["tool_info"] = tool_info 2236 2237 # Add to session_history 2238 if not hasattr(self, "session_history"): 2239 self.session_history = [] 2240 2241 self.session_history.append(standard_message) 2242 2243 # Save to persistent session if enabled 2244 if self.session_enabled and self.session_id: 2245 # Convert standard message to session-compatible format 2246 session_msg = { 2247 "role": role, 2248 "content": content, 2249 "id": message_id, 2250 "timestamp": timestamp 2251 } 2252 2253 # Add tool-related fields if present 2254 if tool_info: 2255 for key, value in tool_info.items(): 2256 session_msg[key] = value 2257 2258 # Store in session 2259 self.session.msg_insert(self.session_id, session_msg) 2260 2261 # Update the SDK-specific format in self.history by running the normalizer 2262 if normalize: 2263 # We only need to normalize the most recent message for efficiency 2264 # Pass a flag indicating we're just normalizing a new message 2265 self._normalizer(force=False, new_message_only=True) 2266 2267 # Log the added message 2268 self._log(f"[MESSAGE ADDED] {role}: {str(content)[:50]}...") 2269 2270 return message_id
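A sketch of direct message insertion, including a tool result carried via tool_info (the id and name values are illustrative):

ai.messages_add(role="user", content="What is 2 + 2?")
msg_id = ai.messages_add(
    role="tool",
    content="4",
    tool_info={"id": "call_123", "name": "calculator", "arguments": {"expr": "2 + 2"}},
)
print(msg_id)  # UUID of the stored message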
2273 def messages_system(self, prompt: str): 2274 """Set or retrieve the current system prompt. 2275 2276 This method manages the system prompt that guides the AI's behavior. 2277 It can be used to both set a new system prompt and retrieve the current one. 2278 When setting a new prompt, it updates the system message in the conversation 2279 history and persists it to the session if enabled. 2280 2281 Args: 2282 prompt (str): The new system prompt to set. If empty or None, 2283 returns the current system prompt without changes. 2284 2285 Returns: 2286 str: The current system prompt after any updates 2287 """ 2288 if not isinstance(prompt, str) or not prompt: 2289 return self.system 2290 2291 # If the prompt hasn't changed, don't do anything 2292 if self.system == prompt: 2293 return self.system 2294 2295 # Update the system prompt 2296 old_system = self.system 2297 self.system = prompt 2298 2299 # For OpenAI, update or insert the system message in history 2300 if self.sdk == "openai": 2301 # Check if there's already a system message 2302 system_index = next((i for i, msg in enumerate(self.history) 2303 if msg.get("role") == "system"), None) 2304 2305 if system_index is not None: 2306 # Update existing system message 2307 self.history[system_index]["content"] = prompt 2308 else: 2309 # Insert new system message at the beginning 2310 self.history.insert(0, {"role": "system", "content": prompt}) 2311 2312 # For Anthropic, system message is not part of history, just save it for API calls 2313 2314 # Log to session only if prompt actually changed 2315 if self.session_enabled and self.session_id and old_system != prompt: 2316 self.session.msg_insert(self.session_id, {"role": "system", "content": prompt}) 2317 2318 return self.system
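A short sketch of both uses; note that the empty-string call is a read, not a write:

current = ai.messages_system("")  # returns the active prompt unchanged
ai.messages_system("You are a terse code reviewer.")
# For the OpenAI SDK this updates (or inserts) the leading system
# message in self.history; for Anthropic it is only stored for API calls.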
2321 def messages(self) -> list: 2322 """Return full session messages (persisted or in-memory).""" 2323 if self.session_enabled and self.session_id: 2324 return self.session.load_full(self.session_id).get("messages", []) 2325 return self.session_history
2328 def messages_length(self) -> int: 2329 """Calculate the total token count for the message history.""" 2330 if not self.encoding: 2331 return 0 2332 2333 total_tokens = 0 2334 for message in self.history: 2335 if message.get("content"): 2336 total_tokens += len(self.encoding.encode(message["content"])) 2337 if message.get("tool_calls"): 2338 for tool_call in message["tool_calls"]: 2339 if tool_call.get("function"): 2340 total_tokens += len(self.encoding.encode(tool_call["function"].get("name", ""))) 2341 total_tokens += len(self.encoding.encode(tool_call["function"].get("arguments", ""))) 2342 return total_tokens
2345 def session_load(self, session_id: Optional[str]): 2346 """Load and normalize messages for a specific session. 2347 2348 This method loads a conversation session from persistent storage and 2349 normalizes the messages to the current SDK format. It handles system 2350 messages, tool calls, and maintains message ordering. If loading fails, 2351 it resets to an empty session with the default system prompt. 2352 2353 Args: 2354 session_id (Optional[str]): The ID of the session to load. 2355 If None, resets to in-memory mode. 2356 2357 Note: 2358 This method will update both session_history and history to match 2359 the loaded session's state. It also ensures proper SDK-specific 2360 message formatting. 2361 """ 2362 self.session_id = session_id 2363 self._last_session_id = session_id 2364 2365 if self.session_enabled and session_id: 2366 try: 2367 # Load raw session data 2368 session_data = self.session.load_full(session_id) 2369 messages = session_data.get("messages", []) 2370 2371 # Convert session format to our standard format 2372 self.session_history = [] 2373 2374 # Track the most recent system message 2375 latest_system_msg = None 2376 2377 for msg in messages: 2378 # Extract fields 2379 role = msg.get("role", "user") 2380 content = msg.get("content", "") 2381 msg_id = msg.get("id", str(uuid.uuid4())) 2382 timestamp = msg.get("timestamp", datetime.now(timezone.utc).isoformat()) 2383 2384 # If this is a system message, track it but don't add to session_history yet 2385 if role == "system": 2386 if latest_system_msg is None or timestamp > latest_system_msg["timestamp"]: 2387 latest_system_msg = { 2388 "role": role, 2389 "content": content, 2390 "id": msg_id, 2391 "timestamp": timestamp, 2392 "metadata": {"sdk": self.sdk} 2393 } 2394 continue 2395 2396 # Build tool_info if present 2397 tool_info = None 2398 if any(key in msg for key in ["tool_use_id", "tool_call_id", "name", "arguments"]): 2399 tool_info = { 2400 "id": msg.get("tool_use_id") or msg.get("tool_call_id"), 2401 "name": msg.get("name", "unknown_tool"), 2402 "arguments": msg.get("arguments", {}) 2403 } 2404 2405 # Create standard message 2406 standard_msg = { 2407 "role": role, 2408 "content": content, 2409 "id": msg_id, 2410 "timestamp": timestamp, 2411 "metadata": { 2412 "sdk": self.sdk 2413 } 2414 } 2415 2416 if tool_info: 2417 standard_msg["metadata"]["tool_info"] = tool_info 2418 2419 self.session_history.append(standard_msg) 2420 2421 # If we found a system message, update the system property and add to history 2422 if latest_system_msg: 2423 self.system = latest_system_msg["content"] 2424 # Insert at the beginning of session_history 2425 self.session_history.insert(0, latest_system_msg) 2426 else: 2427 # If no system message was found, add the current system message 2428 self.messages_add(role="system", content=self.system) 2429 2430 # Normalize to current SDK format 2431 self._normalizer(force=True) 2432 2433 self._log(f"[SESSION] Switched to session '{session_id}'") 2434 except Exception as e: 2435 self.logger.error(f"Failed to load session '{session_id}': {e}") 2436 self.session_reset() 2437 else: 2438 # Reset to empty state with system message 2439 self.session_reset()
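A session-switching sketch, assuming session support was enabled at construction time (the storage path is illustrative):

ai = Interactor(session_enabled=True, session_path="~/.ai_sessions")
sid = ai.session.create("research-notes")    # Session.create, documented below
ai.session_load(sid)                         # load and normalize into SDK format
ai.interact("Summarize what we know so far.")
ai.session_load(None)                        # back to in-memory mode (resets)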
2442 def session_reset(self): 2443 """Reset the current session state and reinitialize to default system prompt. 2444 2445 This method performs a complete reset of the conversation state: 2446 1. Clears all message history 2447 2. Disables session ID tracking 2448 3. Returns to in-memory mode 2449 4. Reinitializes with the default system prompt 2450 2451 The reset is useful for starting fresh conversations or recovering 2452 from error states. It maintains the basic system configuration while 2453 clearing all conversation context. 2454 """ 2455 self.session_id = None 2456 self._last_session_id = None 2457 2458 # Clear histories 2459 self.session_history = [] 2460 self.history = [] 2461 2462 # Reapply the system message 2463 if hasattr(self, "system") and self.system: 2464 # Add to session_history 2465 self.messages_add(role="system", content=self.system) 2466 else: 2467 # Ensure we have a default system message 2468 self.system = "You are a helpful Assistant." 2469 self.messages_add(role="system", content=self.system) 2470 2471 self._log("[SESSION] Reset to in-memory mode")
2915 def track_token_usage(self): 2916 """Track and return token usage across the conversation history. 2917 2918 This method maintains a history of token usage measurements and provides 2919 current usage statistics. It tracks: 2920 - Current token count 2921 - Context length limit 2922 - Usage percentage 2923 - Historical measurements 2924 - Current provider and model info 2925 2926 Returns: 2927 dict: Dictionary containing: 2928 - current: Current token count 2929 - limit: Maximum context length 2930 - percentage: Usage as percentage of limit 2931 - history: Last 10 measurements with timestamps 2932 - provider: Current provider name 2933 - model: Current model name 2934 2935 Note: 2936 The history is limited to the last 100 measurements to prevent 2937 unbounded memory growth. 2938 """ 2939 if not hasattr(self, "_token_history"): 2940 self._token_history = [] 2941 2942 # Count current tokens 2943 current_count = self._count_tokens(self.history) 2944 2945 # Add to history 2946 timestamp = datetime.now(timezone.utc).isoformat() 2947 self._token_history.append({ 2948 "timestamp": timestamp, 2949 "count": current_count, 2950 "limit": self.context_length, 2951 "provider": self.provider, 2952 "model": self.model 2953 }) 2954 2955 # Keep only the last 100 measurements to avoid unlimited growth 2956 if len(self._token_history) > 100: 2957 self._token_history = self._token_history[-100:] 2958 2959 # Return current tracking info 2960 return { 2961 "current": current_count, 2962 "limit": self.context_length, 2963 "percentage": round((current_count / self.context_length) * 100, 1) if self.context_length else 0, 2964 "history": self._token_history[-10:], # Return last 10 measurements 2965 "provider": self.provider, 2966 "model": self.model 2967 }
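A monitoring sketch built on the returned dictionary; the trim policy is illustrative, not part of the library:

usage = ai.track_token_usage()
print(f"{usage['current']}/{usage['limit']} tokens ({usage['percentage']}%)")
if usage["percentage"] > 90:
    ai.session_reset()  # one possible recovery: start a fresh context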
2970 def get_message_token_breakdown(self): 2971 """Analyze token usage by message type and provide a detailed breakdown. 2972 2973 This method performs a detailed analysis of token usage across the 2974 conversation history, breaking down usage by: 2975 - Message role (system, user, assistant, tool) 2976 - Content type (text, tool calls, tool results) 2977 - Individual message statistics 2978 2979 Returns: 2980 dict: Token usage breakdown containing: 2981 - total: Total tokens used 2982 - by_role: Tokens used by each role 2983 - by_type: Tokens used by content type 2984 - messages: List of individual message stats including: 2985 - index: Message position 2986 - role: Message role 2987 - tokens: Tokens used 2988 - has_tools: Whether message contains tool calls 2989 2990 Note: 2991 This analysis is useful for understanding token usage patterns 2992 and optimizing conversation context. 2993 """ 2994 breakdown = { 2995 "total": 0, 2996 "by_role": { 2997 "system": 0, 2998 "user": 0, 2999 "assistant": 0, 3000 "tool": 0 3001 }, 3002 "by_type": { 3003 "text": 0, 3004 "tool_calls": 0, 3005 "tool_results": 0 3006 }, 3007 "messages": [] 3008 } 3009 3010 # Analyze each message 3011 for i, msg in enumerate(self.history): 3012 msg_tokens = self._count_tokens([msg]) 3013 role = msg.get("role", "unknown") 3014 3015 # Track by role 3016 if role in breakdown["by_role"]: 3017 breakdown["by_role"][role] += msg_tokens 3018 3019 # Track by content type 3020 if role == "assistant" and msg.get("tool_calls"): 3021 breakdown["by_type"]["tool_calls"] += msg_tokens 3022 elif role == "tool": 3023 breakdown["by_type"]["tool_results"] += msg_tokens 3024 else: 3025 breakdown["by_type"]["text"] += msg_tokens 3026 3027 # Add individual message data 3028 breakdown["messages"].append({ 3029 "index": i, 3030 "role": role, 3031 "tokens": msg_tokens, 3032 "has_tools": bool(msg.get("tool_calls") or msg.get("tool_use") or 3033 (isinstance(msg.get("content"), list) and 3034 any(isinstance(c, dict) and c.get("type") in ["tool_use", "tool_result"] 3035 for c in msg.get("content", [])))) 3036 }) 3037 3038 # Update total 3039 breakdown["total"] += msg_tokens 3040 3041 return breakdown
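A sketch that prints the breakdown and locates the heaviest message:

breakdown = ai.get_message_token_breakdown()
print("total:", breakdown["total"])
for role, tokens in breakdown["by_role"].items():
    print(f"  {role}: {tokens}")
if breakdown["messages"]:
    top = max(breakdown["messages"], key=lambda m: m["tokens"])
    print("largest:", top["index"], top["role"], top["tokens"])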
19class Session: 20 def __init__(self, directory: str = None): 21 """ 22 Initialize the session manager and ensure the session directory exists. 23 24 Args: 25 directory (str): Filesystem path for session storage. Must not be None or empty. 26 27 Raises: 28 ValueError: If directory is None or not a string. 29 OSError: If the directory cannot be created or accessed. 30 """ 31 if not directory: 32 raise ValueError("Session directory must be a valid non-empty string path.") 33 34 try: 35 self.path = Path(os.path.expanduser(directory)) 36 self.path.mkdir(parents=True, exist_ok=True) 37 except Exception as e: 38 raise OSError(f"Failed to initialize session directory '{directory}': {e}") 39 40 # --------------------------- 41 # Core CRUD 42 # --------------------------- 43 44 def list(self) -> List[Dict]: 45 """ 46 Return metadata for all sessions in the directory. 47 48 Returns: 49 List[Dict]: Sorted list of session metadata dictionaries. 50 """ 51 out = [] 52 for file in self.path.glob("*.json"): 53 try: 54 with open(file, "r") as f: 55 d = json.load(f) 56 out.append({ 57 "id": d.get("id"), 58 "name": d.get("name"), 59 "created": d.get("created"), 60 "tags": d.get("tags", []), 61 "summary": d.get("summary") 62 }) 63 except Exception: 64 continue 65 return sorted(out, key=lambda x: x["created"], reverse=True) 66 67 def create(self, name: str, tags: Optional[List[str]] = None) -> str: 68 """ 69 Create and persist a new session. 70 71 Args: 72 name (str): Name of the new session. 73 tags (List[str], optional): Optional list of tags. 74 75 Returns: 76 str: Unique session ID of the new session. 77 """ 78 sid = str(uuid.uuid4()) 79 session = { 80 "id": sid, 81 "name": name, 82 "created": datetime.now(timezone.utc).isoformat(), 83 "parent": None, 84 "branch_point": None, 85 "tags": tags or [], 86 "summary": None, 87 "messages": [] 88 } 89 self._save_file(sid, session) 90 return sid 91 92 def load(self, session_id: str) -> List[Dict]: 93 """ 94 Return OpenAI-compatible message list from a session. 95 96 Filters out internal keys and leaves only standard API-compatible fields. 97 98 Args: 99 session_id (str): ID of the session to load. 100 101 Returns: 102 List[Dict]: List of clean message dictionaries. 103 """ 104 session = self._read_file(session_id) 105 return [ 106 {k: v for k, v in m.items() if k in { 107 "role", "content", "tool_calls", "name", "function_call", "tool_call_id" 108 }} for m in session.get("messages", []) 109 ] 110 111 def load_full(self, session_id: str) -> Dict: 112 """ 113 Return the complete session file as-is. 114 115 Args: 116 session_id (str): ID of the session. 117 118 Returns: 119 Dict: Entire raw session data. 120 """ 121 return self._read_file(session_id) 122 123 def delete(self, session_id: str): 124 """ 125 Delete a session file from disk. 126 127 Args: 128 session_id (str): ID of the session to delete. 129 """ 130 file = self.path / f"{session_id}.json" 131 if file.exists(): 132 file.unlink() 133 134 def update(self, session_id: str, key: str, value: Any): 135 """ 136 Update a top-level key in a session file. 137 138 Args: 139 session_id (str): Session ID. 140 key (str): Field to update. 141 value (Any): New value for the field. 142 """ 143 session = self._read_file(session_id) 144 session[key] = value 145 self._save_file(session_id, session) 146 147 # --------------------------- 148 # Message Operations 149 # --------------------------- 150 151 def msg_insert(self, session_id: str, message: Dict) -> str: 152 """ 153 Insert a new message into a session. 
154 155 Args: 156 session_id (str): Session ID. 157 message (Dict): Message dictionary to insert. 158 159 Returns: 160 str: ID of the inserted message. 161 """ 162 session = self._read_file(session_id) 163 entry = { 164 "id": str(uuid.uuid4()), 165 "timestamp": datetime.now(timezone.utc).isoformat(), 166 **message 167 } 168 session["messages"].append(entry) 169 self._save_file(session_id, session) 170 return entry["id"] 171 172 def msg_get(self, session_id: str, message_id: str) -> Optional[Dict]: 173 """ 174 Retrieve a specific message from a session. 175 176 Args: 177 session_id (str): Session ID. 178 message_id (str): ID of the message to retrieve. 179 180 Returns: 181 Optional[Dict]: The message if found, else None. 182 """ 183 session = self._read_file(session_id) 184 for msg in session.get("messages", []): 185 if msg.get("id") == message_id: 186 return msg 187 return None 188 189 def msg_index(self, session_id: str, message_id: str) -> Optional[int]: 190 """ 191 Get the index of a message within a session. 192 193 Args: 194 session_id (str): Session ID. 195 message_id (str): Message ID. 196 197 Returns: 198 Optional[int]: Index if found, else None. 199 """ 200 session = self._read_file(session_id) 201 for i, msg in enumerate(session.get("messages", [])): 202 if msg.get("id") == message_id: 203 return i 204 return None 205 206 def msg_update(self, session_id: str, message_id: str, new_content: str) -> bool: 207 """ 208 Update the content of a specific message. 209 210 Args: 211 session_id (str): Session ID. 212 message_id (str): Message ID. 213 new_content (str): New content for the message. 214 215 Returns: 216 bool: True if update succeeded, False otherwise. 217 """ 218 session = self._read_file(session_id) 219 for m in session["messages"]: 220 if m.get("id") == message_id: 221 m["content"] = new_content 222 self._save_file(session_id, session) 223 return True 224 return False 225 226 def msg_delete(self, session_id: str, message_id: str) -> bool: 227 """ 228 Delete a message from a session. 229 230 Args: 231 session_id (str): Session ID. 232 message_id (str): Message ID. 233 234 Returns: 235 bool: True if deletion occurred, False otherwise. 236 """ 237 session = self._read_file(session_id) 238 before = len(session["messages"]) 239 session["messages"] = [m for m in session["messages"] if m.get("id") != message_id] 240 self._save_file(session_id, session) 241 return len(session["messages"]) < before 242 243 # --------------------------- 244 # Branching & Summarization 245 # --------------------------- 246 247 def branch(self, from_id: str, message_id: str, new_name: str) -> str: 248 """Create a new session by branching from a specific message. 249 250 This method creates a new session that branches from an existing one at a specific 251 message point. The new session inherits all messages up to and including the 252 specified message, then starts fresh from there. 253 254 Args: 255 from_id (str): ID of the source session to branch from. 256 message_id (str): ID of the message to branch at. 257 new_name (str): Name for the new branched session. 258 259 Returns: 260 str: ID of the newly created branched session. 261 262 Raises: 263 ValueError: If the source session or message ID is not found. 
264 """ 265 # Get source session 266 source = self._read_file(from_id) 267 if not source: 268 raise ValueError(f"Source session '{from_id}' not found") 269 270 # Find the branch point 271 branch_index = self.msg_index(from_id, message_id) 272 if branch_index is None: 273 raise ValueError(f"Message '{message_id}' not found in session '{from_id}'") 274 275 # Create new session 276 new_id = self.create(new_name, source.get("tags", [])) 277 new_session = self._read_file(new_id) 278 279 # Copy messages up to branch point 280 new_session["messages"] = source["messages"][:branch_index + 1] 281 new_session["parent"] = from_id 282 new_session["branch_point"] = message_id 283 284 # Save and return 285 self._save_file(new_id, new_session) 286 return new_id 287 288 def summarize(self, interactor, session_id: str) -> str: 289 """Generate a summary of the session using the provided interactor. 290 291 This method uses the AI interactor to analyze the session content and generate 292 a concise summary. The summary is stored in the session metadata and returned. 293 294 Args: 295 interactor: An AI interactor instance capable of generating summaries. 296 session_id (str): ID of the session to summarize. 297 298 Returns: 299 str: The generated summary text. 300 301 Note: 302 The summary is automatically stored in the session metadata and can be 303 retrieved later using load_full(). 304 """ 305 session = self._read_file(session_id) 306 if not session: 307 return "" 308 309 # Get clean message list 310 messages = self.load(session_id) 311 if not messages: 312 return "" 313 314 # Generate summary 315 summary = interactor.interact( 316 "Summarize this conversation in 2-3 sentences:", 317 tools=False, 318 stream=False, 319 markdown=False 320 ) 321 322 # Store and return 323 session["summary"] = summary 324 self._save_file(session_id, session) 325 return summary 326 327 # --------------------------- 328 # Search Capabilities 329 # --------------------------- 330 331 def search(self, query: str, session_id: Optional[str] = None) -> List[Dict]: 332 """Search for messages containing the query text within a session or all sessions. 333 334 This method performs a case-insensitive text search across message content. 335 If a session_id is provided, only searches within that session. Otherwise, 336 searches across all sessions. 337 338 Args: 339 query (str): Text to search for. 340 session_id (Optional[str]): Optional session ID to limit search scope. 341 342 Returns: 343 List[Dict]: List of matching messages with their session context. 
344 Each dict contains: 345 - session_id: ID of the containing session 346 - message: The matching message 347 - context: Surrounding messages for context 348 """ 349 results = [] 350 query = query.lower() 351 352 # Determine search scope 353 if session_id: 354 sessions = [(session_id, self._read_file(session_id))] 355 else: 356 sessions = [(f.stem, self._read_file(f.stem)) for f in self.path.glob("*.json")] 357 358 # Search each session 359 for sid, session in sessions: 360 if not session: 361 continue 362 363 messages = session.get("messages", []) 364 for i, msg in enumerate(messages): 365 content = str(msg.get("content", "")).lower() 366 if query in content: 367 # Get context (2 messages before and after) 368 start = max(0, i - 2) 369 end = min(len(messages), i + 3) 370 context = messages[start:end] 371 372 results.append({ 373 "session_id": sid, 374 "message": msg, 375 "context": context 376 }) 377 378 return results 379 380 def search_meta(self, query: str) -> List[Dict]: 381 """Search session metadata (name, tags, summary) for matching sessions. 382 383 This method performs a case-insensitive search across session metadata fields 384 including name, tags, and summary. It returns matching sessions with their 385 full metadata. 386 387 Args: 388 query (str): Text to search for in metadata. 389 390 Returns: 391 List[Dict]: List of matching session metadata dictionaries. 392 Each dict contains: 393 - id: Session ID 394 - name: Session name 395 - created: Creation timestamp 396 - tags: List of tags 397 - summary: Session summary if available 398 """ 399 results = [] 400 query = query.lower() 401 402 for file in self.path.glob("*.json"): 403 try: 404 with open(file, "r") as f: 405 session = json.load(f) 406 407 # Check metadata fields 408 name = str(session.get("name", "")).lower() 409 tags = [str(t).lower() for t in session.get("tags", [])] 410 summary = str(session.get("summary", "")).lower() 411 412 if (query in name or 413 any(query in tag for tag in tags) or 414 query in summary): 415 results.append({ 416 "id": session.get("id"), 417 "name": session.get("name"), 418 "created": session.get("created"), 419 "tags": session.get("tags", []), 420 "summary": session.get("summary") 421 }) 422 except Exception: 423 continue 424 425 return sorted(results, key=lambda x: x["created"], reverse=True) 426 427 # --------------------------- 428 # Internal I/O 429 # --------------------------- 430 431 def _read_file(self, session_id: str) -> Dict: 432 """Read and parse a session file from disk. 433 434 This internal method handles reading and parsing session files. 435 It ensures proper error handling and returns an empty session 436 structure if the file doesn't exist or is invalid. 437 438 Args: 439 session_id (str): ID of the session to read. 440 441 Returns: 442 Dict: Session data dictionary or empty session structure. 443 """ 444 file = self.path / f"{session_id}.json" 445 if not file.exists(): 446 return { 447 "id": session_id, 448 "name": "New Session", 449 "created": datetime.now(timezone.utc).isoformat(), 450 "messages": [] 451 } 452 453 try: 454 with open(file, "r") as f: 455 return json.load(f) 456 except Exception: 457 return { 458 "id": session_id, 459 "name": "New Session", 460 "created": datetime.now(timezone.utc).isoformat(), 461 "messages": [] 462 } 463 464 def _save_file(self, session_id: str, data: Dict): 465 """Write session data to disk. 466 467 This internal method handles writing session data to disk. 468 It ensures proper error handling and atomic writes. 
469 470 Args: 471 session_id (str): ID of the session to save. 472 data (Dict): Session data to write. 473 474 Raises: 475 OSError: If the file cannot be written. 476 """ 477 file = self.path / f"{session_id}.json" 478 temp_file = file.with_suffix(".tmp") 479 480 try: 481 # Write to temporary file first 482 with open(temp_file, "w") as f: 483 json.dump(data, f, indent=2) 484 485 # Atomic rename 486 temp_file.replace(file) 487 except Exception as e: 488 if temp_file.exists(): 489 temp_file.unlink() 490 raise OSError(f"Failed to save session '{session_id}': {e}")
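The Session store also works standalone; a sketch covering create, insert, branch, and the two search paths (import path and directory assumed):

from interactor import Session  # import path assumed

store = Session("~/.ai_sessions")
sid = store.create("support-chat", tags=["demo"])
mid = store.msg_insert(sid, {"role": "user", "content": "Hello there"})

# New session keeping messages up to and including `mid`
fork = store.branch(sid, mid, "support-chat-fork")

hits = store.search("hello")      # case-insensitive content search
meta = store.search_meta("demo")  # name/tags/summary search
print(fork, len(hits), [m["name"] for m in meta])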