interactor
64class Interactor: 65 def __init__( 66 self, 67 base_url: Optional[str] = None, 68 api_key: Optional[str] = None, 69 model: str = "openai:gpt-4o-mini", 70 fallback_model = "ollama:mistral-nemo:latest", 71 tools: Optional[bool] = True, 72 stream: bool = True, 73 quiet: bool = False, 74 context_length: int = 128000, 75 max_retries: int = 3, 76 retry_delay: float = 1.0, 77 log_path: Optional[str] = None, 78 raw: Optional[bool] = False, 79 session_enabled: bool = False, 80 session_id: Optional[str] = None, 81 session_path: Optional[str] = None 82 ): 83 """Initialize the universal AI interaction client. 84 85 Args: 86 base_url: Optional base URL for the API. If None, uses the provider's default URL. 87 api_key: Optional API key. If None, attempts to use environment variables based on provider. 88 model: Model identifier in format "provider:model_name" (e.g., "openai:gpt-4o-mini"). 89 tools: Enable (True) or disable (False) tool calling; None for auto-detection based on model support. 90 stream: Enable (True) or disable (False) streaming responses. 91 context_length: Maximum number of tokens to maintain in conversation history. 92 max_retries: Maximum number of retries for failed API calls. 93 retry_delay: Initial delay (in seconds) for exponential backoff retries. 94 session_enabled: Enable persistent session support. 95 session_id: Optional session ID to load messages from. 96 97 Raises: 98 ValueError: If provider is not supported or API key is missing for non-Ollama providers. 99 """ 100 self.system = "You are a helpful Assistant." 101 self.raw = raw 102 self.quiet = quiet 103 self.logger = logging.getLogger(f"InteractorLogger_{id(self)}") 104 self.logger.setLevel(logging.DEBUG) 105 self.providers = { 106 "openai": { 107 "sdk": "openai", 108 "base_url": "https://api.openai.com/v1", 109 "api_key": api_key or os.getenv("OPENAI_API_KEY") or None 110 }, 111 "ollama": { 112 "sdk": "openai", 113 "base_url": "http://localhost:11434/v1", 114 "api_key": api_key or "ollama" 115 }, 116 "nvidia": { 117 "sdk": "openai", 118 "base_url": "https://integrate.api.nvidia.com/v1", 119 "api_key": api_key or os.getenv("NVIDIA_API_KEY") or None 120 }, 121 "google": { 122 "sdk": "openai", 123 "base_url": "https://generativelanguage.googleapis.com/v1beta/openai", 124 "api_key": api_key or os.getenv("GEMINI_API_KEY") or None 125 }, 126 "anthropic": { 127 "sdk": "anthropic", 128 "base_url": "https://api.anthropic.com/v1", 129 "api_key": api_key or os.getenv("ANTHROPIC_API_KEY") or None 130 }, 131 "mistral": { 132 "sdk": "openai", 133 "base_url": "https://api.mistral.ai/v1", 134 "api_key": api_key or os.getenv("MISTRAL_API_KEY") or None 135 }, 136 "deepseek": { 137 "sdk": "openai", 138 "base_url": "https://api.deepseek.com", 139 "api_key": api_key or os.getenv("DEEPSEEK_API_KEY") or None 140 }, 141 "grok": { 142 "sdk": "grok", 143 "base_url": "https://api.x.ai/v1", 144 "api_key": api_key or os.getenv("GROK_API_KEY") or None 145 } 146 } 147 148 149 # Console log handler (always enabled at WARNING+) 150 if not self.logger.handlers: 151 console_handler = logging.StreamHandler(sys.stdout) 152 console_handler.setLevel(logging.WARNING) 153 console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) 154 self.logger.addHandler(console_handler) 155 156 self._log_enabled = False 157 if log_path: 158 file_handler = logging.FileHandler(log_path) 159 file_handler.setLevel(logging.DEBUG) 160 file_handler.setFormatter(logging.Formatter( 161 "%(asctime)s - %(levelname)s - %(message)s", 162 datefmt="%Y-%m-%d %H:%M:%S" 163 )) 164 self.logger.addHandler(file_handler) 165 self._log_enabled = True 166 167 168 self.token_estimate = 0 169 self.last_token_estimate = 0 170 self.stream = stream 171 self.tools = [] 172 self.session_history = [] 173 self.history = [] 174 self.context_length = context_length 175 self.encoding = None 176 self.max_retries = max_retries 177 self.retry_delay = retry_delay 178 self.reveal_tool = [] 179 self.fallback_model = fallback_model 180 self.sdk = None 181 182 # Session support 183 self.session_enabled = session_enabled 184 self.session_id = session_id 185 self._last_session_id = session_id 186 self.session = Session(directory=session_path) if session_enabled else None 187 188 189 if model is None: 190 model = "openai:gpt-4o-mini" 191 192 # Initialize model + encoding 193 self._setup_client(model, base_url, api_key) 194 self.tools_enabled = self.tools_supported if tools is None else tools and self.tools_supported 195 self._setup_encoding() 196 self.messages_add(role="system", content=self.system) 197 198 199 def _log(self, message: str, level: str = "info"): 200 """Log a message to the configured logging handlers. 201 202 This internal method handles logging to both console and file handlers 203 if configured. It respects the logging level and only logs if logging 204 is enabled. 205 206 Args: 207 message (str): The message to log 208 level (str): Logging level - one of "debug", "info", "warning", "error" 209 """ 210 if self._log_enabled: 211 getattr(self.logger, level)(message) 212 213 214 def _setup_client( 215 self, 216 model: Optional[str] = None, 217 base_url: Optional[str] = None, 218 api_key: Optional[str] = None 219 ): 220 """Initialize or reconfigure the Interactor for the given model and SDK. 221 222 Ensures idempotent setup, assigns SDK-specific clients and tool handling logic, 223 and normalizes history to match the provider-specific message schema. 224 """ 225 if not model: 226 raise ValueError("Model must be specified as 'provider:model_name'") 227 228 provider, model_name = model.split(":", 1) 229 230 if not hasattr(self, "session_history"): 231 self.session_history = [] 232 233 # Skip setup if nothing has changed (client may not yet exist on first call) 234 if ( 235 hasattr(self, "client") 236 and self.client 237 and self.provider == provider 238 and self.model == model_name 239 and self.base_url == (base_url or self.base_url) 240 ): 241 return 242 243 if provider not in self.providers: 244 raise ValueError(f"Unsupported provider: {provider}. Supported: {list(self.providers.keys())}") 245 246 # Load provider configuration 247 provider_config = self.providers[provider] 248 self.sdk = provider_config.get("sdk", "openai") 249 self.provider = provider 250 self.model = model_name 251 self.base_url = base_url or provider_config["base_url"] 252 effective_api_key = api_key or provider_config["api_key"] 253 254 if not effective_api_key and provider != "ollama": 255 raise ValueError(f"API key not provided and not found in environment for {provider.upper()}_API_KEY") 256 257 # SDK-specific configuration 258 if self.sdk == "openai": 259 self.client = openai.OpenAI(base_url=self.base_url, api_key=effective_api_key) 260 self.async_client = openai.AsyncOpenAI(base_url=self.base_url, api_key=effective_api_key) 261 self.sdk_runner = self._openai_runner 262 self.tool_key = "tool_call_id" 263 264 elif self.sdk == "anthropic": 265 self.client = anthropic.Anthropic(api_key=effective_api_key) 266 self.async_client = anthropic.AsyncAnthropic(api_key=effective_api_key) 267 self.sdk_runner = self._anthropic_runner 268 self.tool_key = "tool_use_id" 269 270 else: 271 raise ValueError(f"Unsupported SDK type: {self.sdk}") 272 273 # Determine tool support 274 self.tools_supported = self._check_tool_support() 275 if not self.tools_supported: 276 self.logger.warning(f"Tool calling not supported for {provider}:{model_name}") 277 278 # Normalize session history to match SDK after any provider/model change 279 self._normalizer(force=True) 280 281 self._log(f"[MODEL] Switched to {provider}:{model_name}") 282 283 284 def _check_tool_support(self) -> bool: 285 """Determine if the current model supports tool calling. 286 287 Returns: 288 bool: True if tools are supported for the active provider/model, False otherwise. 289 """ 290 try: 291 if self.sdk == "openai": 292 response = self.client.chat.completions.create( 293 model=self.model, 294 messages=[{"role": "user", "content": "Test tool support."}], 295 stream=False, 296 tools=[{ 297 "type": "function", 298 "function": { 299 "name": "test_tool", 300 "description": "Check tool support", 301 "parameters": { 302 "type": "object", 303 "properties": { 304 "query": {"type": "string"} 305 }, 306 "required": ["query"] 307 } 308 } 309 }], 310 tool_choice="auto" 311 ) 312 message = response.choices[0].message 313 return bool(message.tool_calls and len(message.tool_calls) > 0) 314 315 elif self.sdk == "anthropic": 316 # For Claude models, we pre-define support based on model ID 317 # Known tool-supporting Claude models 318 claude_tool_models = ["claude-3-opus", "claude-3-sonnet", "claude-3-haiku", 319 "claude-3.5-sonnet", "claude-3.7-sonnet"] 320 321 # Check if the current model supports tools 322 for supported_model in claude_tool_models: 323 if supported_model in self.model.lower(): 324 self._log(f"[TOOLS] Anthropic model {self.model} is known to support tools") 325 return True 326 327 # If not explicitly supported, try to test 328 try: 329 _ = self.client.messages.create( 330 model=self.model, 331 messages=[{"role": "user", "content": "What's the weather?"}], 332 tools=[{ 333 "name": "test_tool", 334 "description": "Check tool support", 335 "input_schema": { 336 "type": "object", 337 "properties": { 338 "query": {"type": "string"} 339 }, 340 "required": ["query"] 341 } 342 }], 343 max_tokens=10 344 ) 345 return True 346 except anthropic.BadRequestError as e: 347 error_msg = str(e).lower() 348 if "tool" in error_msg and "not supported" in error_msg: 349 self._log(f"[TOOLS] Anthropic model {self.model} does not support tools: {e}") 350 return False 351 if "not a supported tool field" in error_msg: 352 self._log(f"[TOOLS] Anthropic API rejected tool format: {e}") 353 return False 354 raise 355 except Exception as e: 356 self._log(f"[TOOLS] Unexpected error testing tool support: {e}", level="error") 357 return False 358 359 else: 360 self.logger.warning(f"Tool support check not implemented for SDK '{self.sdk}'") 361 return False 362 363 except Exception as e: 364 self.logger.error(f"Tool support check failed for {self.provider}:{self.model} — {e}") 365 return False 366 367 368 def add_function( 369 self, 370 external_callable: Callable, 371 name: Optional[str] = None, 372 description: Optional[str] = None, 373 override: bool = False, 374 disabled: bool = False, 375 schema_extensions: Optional[Dict[str, Any]] = None 376 ): 377 """ 378 Register a function for LLM tool calling with full type hints and metadata. 379 380 Args: 381 external_callable (Callable): The function to register. 382 name (Optional[str]): Optional custom name. Defaults to function's __name__. 383 description (Optional[str]): Optional custom description. Defaults to first line of docstring. 384 override (bool): If True, replaces an existing tool with the same name. 385 disabled (bool): If True, registers the function in a disabled state. 386 schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to 387 schema extensions that override or add to the auto-generated schema. 388 389 Raises: 390 ValueError: If the callable is invalid or duplicate name found without override. 391 392 Example: 393 interactor.add_function( 394 my_tool, 395 override=True, 396 disabled=False, 397 schema_extensions={ 398 "param1": {"minimum": 0, "maximum": 100}, 399 "param2": {"format": "email"} 400 } 401 ) 402 """ 403 def _python_type_to_schema(ptype: Any) -> dict: 404 """Convert a Python type annotation to OpenAI-compatible JSON Schema.""" 405 # Handle None case 406 if ptype is None: 407 return {"type": "null"} 408 409 # Get the origin and arguments of the type 410 origin = get_origin(ptype) 411 args = get_args(ptype) 412 413 # Handle Union types (including Optional) 414 if origin is Union: 415 # Check for Optional (Union with None) 416 none_type = type(None) 417 if none_type in args: 418 non_none = [a for a in args if a is not none_type] 419 if len(non_none) == 1: 420 inner = _python_type_to_schema(non_none[0]) 421 inner_copy = inner.copy() 422 inner_copy["nullable"] = True 423 return inner_copy 424 # Multiple types excluding None 425 types = [_python_type_to_schema(a) for a in non_none] 426 return {"anyOf": types, "nullable": True} 427 # Regular Union without None 428 return {"anyOf": [_python_type_to_schema(a) for a in args]} 429 430 # Handle List and similar container types 431 if origin in (list, List): 432 item_type = args[0] if args else Any 433 if item_type is Any: 434 return {"type": "array"} 435 return {"type": "array", "items": _python_type_to_schema(item_type)} 436 437 # Handle Dict types with typing info 438 if origin in (dict, Dict): 439 if not args or len(args) != 2: 440 return {"type": "object"} 441 442 key_type, val_type = args 443 # We can only really use val_type in JSON Schema 444 if val_type is not Any and val_type is not object: 445 return { 446 "type": "object", 447 "additionalProperties": _python_type_to_schema(val_type) 448 } 449 return {"type": "object"} 450 451 # Handle Literal types for enums 452 if origin is Literal: 453 values = args 454 # Try to determine type from values 455 if all(isinstance(v, str) for v in values): 456 return {"type": "string", "enum": list(values)} 457 elif all(isinstance(v, bool) for v in values): 458 return {"type": "boolean", "enum": list(values)} 459 elif all(isinstance(v, (int, float)) for v in values): 460 return {"type": "number", "enum": list(values)} 461 else: 462 # Mixed types, use anyOf 463 return {"anyOf": [{"type": _get_json_type(v), "enum": [v]} for v in values]} 464 465 # Handle basic types 466 if ptype is str: 467 return {"type": "string"} 468 if ptype is int: 469 return {"type": "integer"} 470 if ptype is float: 471 return {"type": "number"} 472 if ptype is bool: 473 return {"type": "boolean"} 474 475 # Handle common datetime types 476 if ptype is datetime: 477 return {"type": "string", "format": "date-time"} 478 if ptype is date: 479 return {"type": "string", "format": "date"} 480 481 # Handle UUID 482 if ptype is uuid.UUID: 483 return {"type": "string", "format": "uuid"} 484 485 # Default to object for any other types 486 return {"type": "object"} 487 488 def _get_json_type(value): 489 """Get the JSON Schema type name for a Python value. 490 491 This helper function maps Python types to their corresponding 492 JSON Schema type names. It handles basic types and provides 493 sensible defaults for complex types. 494 495 Args: 496 value: The Python value to get the JSON type for 497 498 Returns: 499 str: The JSON Schema type name ('string', 'number', 'boolean', 500 'array', 'object', or 'object' as default) 501 """ 502 if isinstance(value, str): 503 return "string" 504 elif isinstance(value, bool): 505 return "boolean" 506 elif isinstance(value, int) or isinstance(value, float): 507 return "number" 508 elif isinstance(value, list): 509 return "array" 510 elif isinstance(value, dict): 511 return "object" 512 else: 513 return "object" # Default 514 515 def _parse_param_docs(docstring: str) -> dict: 516 """Extract parameter descriptions from a docstring.""" 517 if not docstring: 518 return {} 519 520 lines = docstring.splitlines() 521 param_docs = {} 522 current_param = None 523 in_params = False 524 525 # Regular expressions for finding parameter sections and param lines 526 param_section_re = re.compile(r"^(Args|Parameters):\s*$") 527 param_line_re = re.compile(r"^\s{4}(\w+)\s*(?:\([^\)]*\))?:\s*(.*)") 528 529 for line in lines: 530 # Check if we're entering the parameters section 531 if param_section_re.match(line.strip()): 532 in_params = True 533 continue 534 535 if in_params: 536 # Skip empty lines 537 if not line.strip(): 538 continue 539 540 # Check for a parameter definition line 541 match = param_line_re.match(line) 542 if match: 543 current_param = match.group(1) 544 param_docs[current_param] = match.group(2).strip() 545 # Check for continuation of a parameter description 546 elif current_param and line.startswith(" " * 8): 547 param_docs[current_param] += " " + line.strip() 548 # If we see a line that doesn't match our patterns, we're out of the params section 549 else: 550 current_param = None 551 552 return param_docs 553 554 # Start of main function logic 555 556 # Skip if tools are disabled 557 if not self.tools_enabled: 558 return 559 560 # Validate input callable 561 if not external_callable: 562 raise ValueError("A valid external callable must be provided.") 563 564 # Set function name, either from parameter or from callable's __name__ 565 function_name = name or external_callable.__name__ 566 567 # Try to get docstring and extract description 568 try: 569 docstring = inspect.getdoc(external_callable) 570 description = description or (docstring.split("\n")[0].strip() if docstring else "No description provided.") 571 except Exception as e: 572 self._log(f"[TOOL] Warning: Could not extract docstring from {function_name}: {e}", level="warning") 573 docstring = "" 574 description = description or "No description provided." 575 576 # Extract parameter documentation from docstring 577 param_docs = _parse_param_docs(docstring) 578 579 # Handle conflicts with existing functions 580 if override: 581 self.delete_function(function_name) 582 elif any(t["function"]["name"] == function_name for t in self.tools): 583 raise ValueError(f"Function '{function_name}' is already registered. Use override=True to replace.") 584 585 # Try to get function signature for parameter info 586 try: 587 signature = inspect.signature(external_callable) 588 except (ValueError, TypeError) as e: 589 raise ValueError(f"Cannot inspect callable '{function_name}': {e}") 590 591 # Process parameters to build schema 592 properties = {} 593 required = [] 594 595 for param_name, param in signature.parameters.items(): 596 # Skip self, cls parameters for instance/class methods 597 if param_name in ("self", "cls") and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: 598 continue 599 600 # Get parameter annotation, defaulting to Any 601 annotation = param.annotation if param.annotation != inspect.Parameter.empty else Any 602 603 try: 604 # Convert Python type to JSON Schema 605 schema = _python_type_to_schema(annotation) 606 607 # Add description from docstring or create a default one 608 schema["description"] = param_docs.get(param_name, f"{param_name} parameter") 609 610 # Add to properties 611 properties[param_name] = schema 612 613 # If no default value is provided, parameter is required 614 if param.default == inspect.Parameter.empty: 615 required.append(param_name) 616 self._log(f"[TOOL] Parameter '{param_name}' is required", level="debug") 617 else: 618 self._log(f"[TOOL] Parameter '{param_name}' has default value: {param.default}", level="debug") 619 620 except Exception as e: 621 self._log(f"[TOOL] Error processing parameter {param_name} for {function_name}: {e}", level="error") 622 # Add a basic object schema as fallback 623 properties[param_name] = { 624 "type": "string", # Default to string instead of object for better compatibility 625 "description": f"{param_name} parameter (type conversion failed)" 626 } 627 628 # For parameters with no default value, mark as required even if processing failed 629 if param.default == inspect.Parameter.empty: 630 required.append(param_name) 631 self._log(f"[TOOL] Parameter '{param_name}' marked as required despite conversion failure", level="debug") 632 633 # Apply schema extensions if provided 634 if schema_extensions: 635 for param_name, extensions in schema_extensions.items(): 636 if param_name in properties: 637 properties[param_name].update(extensions) 638 639 # Create parameters object with proper placement of 'required' field 640 parameters = { 641 "type": "object", 642 "properties": properties, 643 } 644 645 # Only add required field if there are required parameters 646 if required: 647 parameters["required"] = required 648 649 # Build the final tool specification 650 tool_spec = { 651 "type": "function", 652 "function": { 653 "name": function_name, 654 "description": description, 655 "parameters": parameters 656 } 657 } 658 659 # Set disabled flag if requested 660 if disabled: 661 tool_spec["function"]["disabled"] = True 662 663 # Add to tools list 664 self.tools.append(tool_spec) 665 666 # Make the function available as an attribute on the instance 667 setattr(self, function_name, external_callable) 668 669 # Log the registration with detailed information 670 self._log(f"[TOOL] Registered function '{function_name}' with {len(properties)} parameters", level="info") 671 if required: 672 self._log(f"[TOOL] Required parameters: {required}", level="info") 673 674 return function_name # Return the name for reference 675 676 677 def disable_function(self, name: str) -> bool: 678 """ 679 Disable a registered tool function by name. 680 681 This marks the function as inactive for tool calling without removing it from the internal registry. 682 The function remains visible in the tool listing but is skipped during tool selection by the LLM. 683 684 Args: 685 name (str): The name of the function to disable. 686 687 Returns: 688 bool: True if the function was found and disabled, False otherwise. 689 690 Example: 691 interactor.disable_function("extract_text") 692 """ 693 for tool in self.tools: 694 if tool["function"]["name"] == name: 695 tool["function"]["disabled"] = True 696 return True 697 return False 698 699 700 def enable_function(self, name: str) -> bool: 701 """ 702 Re-enable a previously disabled tool function by name. 703 704 This removes the 'disabled' flag from a tool function, making it available again for LLM use. 705 706 Args: 707 name (str): The name of the function to enable. 708 709 Returns: 710 bool: True if the function was found and enabled, False otherwise. 711 712 Example: 713 interactor.enable_function("extract_text") 714 """ 715 for tool in self.tools: 716 if tool["function"]["name"] == name: 717 tool["function"].pop("disabled", None) 718 return True 719 return False 720 721 722 def delete_function(self, name: str) -> bool: 723 """ 724 Permanently remove a registered tool function from the Interactor. 725 726 This deletes both the tool metadata and the callable attribute, making it fully inaccessible 727 from the active session. Useful for dynamically trimming the toolset. 728 729 Args: 730 name (str): The name of the function to delete. 731 732 Returns: 733 bool: True if the function was found and removed, False otherwise. 734 735 Example: 736 interactor.delete_function("extract_text") 737 """ 738 before = len(self.tools) 739 self.tools = [tool for tool in self.tools if tool["function"]["name"] != name] 740 if hasattr(self, name): 741 delattr(self, name) 742 return len(self.tools) < before 743 744 745 def list_functions(self) -> List[Dict[str, Any]]: 746 """Get the list of registered functions for tool calling. 747 748 Returns: 749 List[Dict[str, Any]]: List of registered functions. 750 """ 751 return self.tools 752 753 754 def list_models( 755 self, 756 providers: Optional[Union[str, List[str]]] = None, 757 filter: Optional[str] = None 758 ) -> List[str]: 759 """Retrieve available models from configured providers. 760 761 Args: 762 providers: Provider name or list of provider names. If None, all are queried. 763 filter: Optional regex to filter model names. 764 765 Returns: 766 List[str]: Sorted list of "provider:model_id" strings. 767 """ 768 models = [] 769 770 if providers is None: 771 providers_to_list = self.providers 772 elif isinstance(providers, str): 773 providers_to_list = {providers: self.providers.get(providers)} 774 elif isinstance(providers, list): 775 providers_to_list = {p: self.providers.get(p) for p in providers} 776 else: 777 return [] 778 779 invalid_providers = [p for p in providers_to_list if p not in self.providers or self.providers[p] is None] 780 if invalid_providers: 781 self.logger.error(f"Invalid providers: {invalid_providers}") 782 return [] 783 784 regex_pattern = None 785 if filter: 786 try: 787 regex_pattern = re.compile(filter, re.IGNORECASE) 788 except re.error as e: 789 self.logger.error(f"Invalid regex pattern: {e}") 790 return [] 791 792 for provider_name, config in providers_to_list.items(): 793 sdk = config.get("sdk", "openai") 794 base_url = config.get("base_url") 795 api_key = config.get("api_key") 796 797 try: 798 if sdk == "openai": 799 client = openai.OpenAI(api_key=api_key, base_url=base_url) 800 response = client.models.list() 801 for model in response.data: 802 model_id = f"{provider_name}:{model.id}" 803 if not regex_pattern or regex_pattern.search(model_id): 804 models.append(model_id) 805 806 elif sdk == "anthropic": 807 client = Anthropic(api_key=api_key) 808 response = client.models.list() 809 for model in response: 810 model_id = f"{provider_name}:{model.id}" 811 if not regex_pattern or regex_pattern.search(model_id): 812 models.append(model_id) 813 else: 814 self.logger.warning(f"SDK '{sdk}' for provider '{provider_name}' is not supported by list_models()") 815 816 except Exception as e: 817 self.logger.error(f"Failed to list models for {provider_name}: {e}") 818 819 return sorted(models, key=str.lower) 820 821 822 async def _retry_with_backoff(self, func: Callable, *args, **kwargs): 823 """Execute a function with exponential backoff retry logic. 824 825 This method implements a robust retry mechanism for API calls with 826 exponential backoff. It handles rate limits, connection errors, and 827 other transient failures. If all retries fail, it will attempt to 828 switch to a fallback model if configured. 829 830 Args: 831 func (Callable): The async function to execute 832 *args: Positional arguments to pass to the function 833 **kwargs: Keyword arguments to pass to the function 834 835 Returns: 836 The result of the function call if successful 837 838 Raises: 839 Exception: If all retries fail and no fallback model is available 840 """ 841 for attempt in range(self.max_retries + 1): 842 try: 843 return await func(*args, **kwargs) 844 845 except (RateLimitError, APIConnectionError, aiohttp.ClientError) as e: 846 if attempt == self.max_retries: 847 model_key = f"{self.provider}:{self.model}" 848 if self.fallback_model and model_key != self.fallback_model: 849 print(f"[yellow]Model '{model_key}' failed. Switching to fallback: {self.fallback_model}[/yellow]") 850 self._setup_client(self.fallback_model) 851 self._setup_encoding() 852 self._normalizer() 853 return await func(*args, **kwargs) # retry once with fallback model 854 else: 855 self.logger.error(f"All {self.max_retries} retries failed: {e}") 856 raise 857 858 delay = self.retry_delay * (2 ** attempt) 859 self.logger.warning(f"Retry {attempt + 1}/{self.max_retries} after {delay}s due to {e}") 860 self._log(f"[RETRY] Attempt {attempt + 1}/{self.max_retries} failed: {e}", level="warning") 861 await asyncio.sleep(delay) 862 863 except OpenAIError as e: 864 self.logger.error(f"OpenAI error: {e}") 865 raise 866 867 except Exception as e: 868 self.logger.error(f"Unexpected error: {e}") 869 raise 870 871 872 def interact( 873 self, 874 user_input: Optional[str], 875 quiet: bool = False, 876 tools: bool = True, 877 stream: bool = True, 878 markdown: bool = False, 879 model: Optional[str] = None, 880 output_callback: Optional[Callable[[str], None]] = None, 881 session_id: Optional[str] = None, 882 raw: Optional[bool] = None, 883 tool_suppress: bool = True, 884 timeout: float = 60.0 885 ) -> Union[Optional[str], "TokenStream"]: 886 """Main universal gateway for all LLM interaction. 887 888 This function serves as the single entry point for all interactions with the language model. 889 When `raw=False` (default), it handles the interaction internally and returns the full response. 890 When `raw=True`, it returns a context manager that yields chunks of the response for custom handling. 891 892 Args: 893 user_input: Text input from the user. 894 quiet: If True, don't print status info or progress. 895 tools: Enable (True) or disable (False) tool calling. 896 stream: Enable (True) or disable (False) streaming responses. 897 markdown: If True, renders content as markdown. 898 model: Optional model override. 899 output_callback: Optional callback to handle the output. 900 session_id: Optional session ID to load messages from. 901 raw: If True, return a context manager instead of handling the interaction internally. 902 If None, use the class-level setting from __init__. 903 tool_suppress: If True and raw=True, filter out tool-related status messages. 904 timeout: Maximum time in seconds to wait for the stream to complete when raw=True. 905 906 Returns: 907 If raw=False: The complete response from the model as a string, or None if there was an error. 908 If raw=True: A context manager that yields chunks of the response as they arrive. 909 910 Example with default mode: 911 response = ai.interact("Tell me a joke") 912 913 Example with raw mode: 914 with ai.interact("Tell me a joke", raw=True) as stream: 915 for chunk in stream: 916 print(chunk, end="", flush=True) 917 """ 918 if not user_input: 919 return None 920 921 if quiet or self.quiet: 922 markdown = False 923 stream = False 924 925 # Determine if we should use raw mode 926 # If raw parameter is explicitly provided, use that; otherwise use class setting 927 use_raw = self.raw if raw is None else raw 928 929 # If raw mode is requested, delegate to interact_raw 930 if use_raw: 931 return self._interact_raw( 932 user_input=user_input, 933 tools=tools, 934 model=model, 935 session_id=session_id, 936 tool_suppress=tool_suppress, 937 timeout=timeout 938 ) 939 940 # Setup model if specified 941 if model: 942 self._setup_client(model) 943 self._setup_encoding() 944 945 # Session handling 946 if self.session_enabled and session_id: 947 self.session_id = session_id 948 self.session_load(session_id) 949 950 # Add user message using messages_add 951 self.messages_add(role="user", content=user_input) 952 953 # Log token count estimate 954 token_count = self._count_tokens(self.history) 955 if not quiet: 956 print(f"[dim]Estimated tokens in context: {token_count} / {self.context_length}[/dim]") 957 958 # Make sure we have enough context space 959 if token_count > self.context_length: 960 if self._cycle_messages(): 961 if not quiet: 962 print("[red]Context window exceeded. Cannot proceed.[/red]") 963 return None 964 965 # Log user input 966 self._log(f"[USER] {user_input}") 967 968 # Handle the actual interaction with complete streaming for all responses 969 result = asyncio.run(self._interact_async_core( 970 user_input=user_input, 971 quiet=quiet, 972 tools=tools, 973 stream=stream, 974 markdown=markdown, 975 output_callback=output_callback 976 )) 977 978 # Log completion for this interaction 979 self._log(f"[INTERACTION] Completed with {len(self.history)} total messages") 980 981 return result 982 983 984 def _interact_raw( 985 self, 986 user_input: Optional[str], 987 tools: bool = True, 988 model: Optional[str] = None, 989 session_id: Optional[str] = None, 990 tool_suppress: bool = True, 991 timeout: float = 60.0 992 ): 993 """ 994 Low-level function that returns a raw stream of tokens from the model. 995 996 This method works as a context manager that yields a generator of streaming tokens. 997 The caller is responsible for handling the output stream. Typically, this is used 998 indirectly through interact() with raw=True. 999 1000 Args: 1001 user_input: Text input from the user. 1002 tools: Enable (True) or disable (False) tool calling. 1003 model: Optional model override. 1004 session_id: Optional session ID to load messages from. 1005 tool_suppress: If True, filter out tool-related status messages. 1006 timeout: Maximum time in seconds to wait for the stream to complete. 1007 1008 Returns: 1009 A context manager that yields a stream of tokens. 1010 1011 Example: 1012 with ai.interact_raw("Hello world") as stream: 1013 for chunk in stream: 1014 print(chunk, end="", flush=True) 1015 """ 1016 if not user_input: 1017 return None 1018 1019 # Setup model if specified 1020 if model: 1021 self._setup_client(model) 1022 self._setup_encoding() 1023 1024 # Session handling 1025 if self.session_enabled and session_id: 1026 self.session_id = session_id 1027 self.session_load(session_id) 1028 1029 # Add user message 1030 self.messages_add(role="user", content=user_input) 1031 1032 # Log token count estimate 1033 token_count = self._count_tokens(self.history) 1034 self._log(f"[STREAM] Estimated tokens in context: {token_count} / {self.context_length}") 1035 1036 # Make sure we have enough context space 1037 if token_count > self.context_length: 1038 if self._cycle_messages(): 1039 self._log("[STREAM] Context window exceeded. Cannot proceed.", level="error") 1040 return None 1041 1042 # Log user input 1043 self._log(f"[USER] {user_input}") 1044 1045 # Create a token stream class using a thread-safe queue 1046 class TokenStream: 1047 def __init__(self, interactor, user_input, tools, tool_suppress, timeout): 1048 """Initialize a new TokenStream instance. 1049 1050 This class provides a context manager for streaming token responses 1051 from the AI model. It handles asynchronous token delivery, tool call 1052 suppression, and timeout management. 1053 1054 Args: 1055 interactor: The parent Interactor instance 1056 user_input: The user's input text 1057 tools: Whether tool calling is enabled 1058 tool_suppress: Whether to suppress tool-related status messages 1059 timeout: Maximum time in seconds to wait for stream completion 1060 """ 1061 self.interactor = interactor 1062 self.user_input = user_input 1063 self.tools = tools 1064 self.tool_suppress = tool_suppress 1065 self.timeout = timeout 1066 self.token_queue = queue.Queue() 1067 self.thread = None 1068 self.result = None 1069 self.error = None 1070 self.completed = False 1071 1072 def __enter__(self): 1073 """Enter the context manager and start the streaming process. 1074 1075 This method initializes the streaming worker thread and returns 1076 self for iteration. The worker thread handles the actual API 1077 communication and token delivery. 1078 1079 Returns: 1080 TokenStream: Self for iteration 1081 """ 1082 # Start the thread for async interaction 1083 def stream_worker(): 1084 """Worker thread that handles the streaming interaction. 1085 1086 This internal function runs in a separate thread to handle 1087 the asynchronous API communication and token delivery. 1088 """ 1089 # Define output callback to put tokens in queue 1090 def callback(text): 1091 """Process and queue incoming text tokens. 1092 1093 This internal function handles incoming text chunks, 1094 optionally filtering tool-related messages, and adds 1095 them to the token queue. 1096 1097 Args: 1098 text: The text chunk to process and queue 1099 """ 1100 # Filter out tool messages if requested 1101 if self.tool_suppress: 1102 try: 1103 # Check if this is a tool status message (JSON format) 1104 data = json.loads(text) 1105 if isinstance(data, dict) and data.get("type") == "tool_call": 1106 # Skip this message 1107 return 1108 except (json.JSONDecodeError, TypeError): 1109 # Not JSON or not a dict, continue normally 1110 pass 1111 1112 # Add to queue 1113 self.token_queue.put(text) 1114 1115 # Run the interaction in a new event loop 1116 loop = asyncio.new_event_loop() 1117 asyncio.set_event_loop(loop) 1118 1119 try: 1120 # Run the interaction 1121 self.result = loop.run_until_complete( 1122 self.interactor._interact_async_core( 1123 user_input=self.user_input, 1124 quiet=True, 1125 tools=self.tools, 1126 stream=True, 1127 markdown=False, 1128 output_callback=callback 1129 ) 1130 ) 1131 # Signal successful completion 1132 self.completed = True 1133 except Exception as e: 1134 self.error = str(e) 1135 self.interactor.logger.error(f"Streaming error: {traceback.format_exc()}") 1136 # Add error information to the queue if we haven't yielded anything yet 1137 if self.token_queue.empty(): 1138 self.token_queue.put(f"Error: {str(e)}") 1139 finally: 1140 # Signal end of stream regardless of success/failure 1141 self.token_queue.put(None) 1142 loop.close() 1143 1144 # Start the worker thread 1145 self.thread = threading.Thread(target=stream_worker) 1146 self.thread.daemon = True 1147 self.thread.start() 1148 1149 # Return self for iteration 1150 return self 1151 1152 def __iter__(self): 1153 """Return self as an iterator. 1154 1155 Returns: 1156 TokenStream: Self for iteration 1157 """ 1158 return self 1159 1160 def __next__(self): 1161 """Get the next token from the stream. 1162 1163 This method implements the iterator protocol, retrieving the next 1164 token from the queue with timeout handling. 1165 1166 Returns: 1167 str: The next token from the stream 1168 1169 Raises: 1170 StopIteration: When the stream is complete or times out 1171 """ 1172 # Get next token from queue with timeout to prevent hanging 1173 try: 1174 token = self.token_queue.get(timeout=self.timeout) 1175 if token is None: 1176 # End of stream 1177 raise StopIteration 1178 return token 1179 except queue.Empty: 1180 # Timeout reached 1181 self.interactor.logger.warning(f"Stream timeout after {self.timeout}s") 1182 if not self.completed and not self.error: 1183 # Clean up the thread - it might be stuck 1184 if self.thread and self.thread.is_alive(): 1185 # We can't forcibly terminate a thread in Python, 1186 # but we can report the issue 1187 self.interactor.logger.error("Stream worker thread is hung") 1188 raise StopIteration 1189 1190 def __exit__(self, exc_type, exc_val, exc_tb): 1191 """Exit the context manager and clean up resources. 1192 1193 This method handles cleanup when the context manager is exited, 1194 including thread cleanup and message history updates. 1195 1196 Args: 1197 exc_type: The exception type if an exception was raised 1198 exc_val: The exception value if an exception was raised 1199 exc_tb: The exception traceback if an exception was raised 1200 1201 Returns: 1202 bool: False to not suppress any exceptions 1203 """ 1204 # Clean up resources 1205 if self.thread and self.thread.is_alive(): 1206 self.thread.join(timeout=2.0) 1207 1208 # Add messages to history if successful 1209 if self.completed and self.result and not exc_type: 1210 if isinstance(self.result, str) and self.result != "No response.": 1211 # If we had a successful completion, ensure the result is in the history 1212 last_msg = self.interactor.history[-1] if self.interactor.history else None 1213 if not last_msg or last_msg.get("role") != "assistant" or last_msg.get("content") != self.result: 1214 # Add a clean assistant message to history if not already there 1215 self.interactor.messages_add(role="assistant", content=self.result) 1216 1217 # If there was an error in the stream processing, log it 1218 if self.error: 1219 self.interactor.logger.error(f"Stream processing error: {self.error}") 1220 1221 return False # Don't suppress exceptions 1222 1223 return TokenStream(self, user_input, tools, tool_suppress, timeout) 1224 1225 1226 async def _interact_async_core( 1227 self, 1228 user_input: str, 1229 quiet: bool = False, 1230 tools: bool = True, 1231 stream: bool = True, 1232 markdown: bool = False, 1233 output_callback: Optional[Callable] = None 1234 ) -> str: 1235 """Main SDK-agnostic async execution pipeline with tool call looping support.""" 1236 # Prepare display handler 1237 live = Live(console=console, refresh_per_second=100) if markdown and stream else None 1238 if live: 1239 live.start() 1240 1241 # Initialize variables for iteration tracking 1242 full_content = "" 1243 max_iterations = 5 # Prevent infinite loops 1244 iterations = 0 1245 1246 # Main interaction loop - continues until no more tool calls or max iterations reached 1247 while iterations < max_iterations: 1248 iterations += 1 1249 1250 try: 1251 # Execute the appropriate SDK runner - history is already normalized 1252 response_data = await self.sdk_runner( 1253 model=self.model, 1254 messages=self.history, 1255 stream=stream, 1256 markdown=markdown, 1257 quiet=quiet if iterations == 1 else False, 1258 live=live, 1259 output_callback=output_callback 1260 ) 1261 1262 # Extract response data 1263 content = response_data.get("content", "") 1264 tool_calls = response_data.get("tool_calls", []) 1265 1266 # Log the response data for debugging 1267 self._log(f"[ITERATION {iterations}] Content: {len(content)} chars, Tool calls: {len(tool_calls)}") 1268 1269 # Add content to full response 1270 if iterations == 1: 1271 full_content = content 1272 elif content: 1273 if full_content and content: 1274 full_content += f"\n{content}" 1275 else: 1276 full_content = content 1277 1278 # Add assistant message with or without tool calls 1279 if tool_calls: 1280 # Process each tool call 1281 for call in tool_calls: 1282 # Add assistant message with tool call 1283 tool_info = { 1284 "id": call["id"], 1285 "name": call["function"]["name"], 1286 "arguments": call["function"]["arguments"] 1287 } 1288 1289 # Add the assistant message with tool call 1290 self.messages_add( 1291 role="assistant", 1292 content=content if len(tool_calls) == 1 else "", 1293 tool_info=tool_info 1294 ) 1295 1296 # Execute the tool 1297 call_name = call["function"]["name"] 1298 call_args = call["function"]["arguments"] 1299 call_id = call["id"] 1300 1301 # Stop Rich Live while executing tool calls 1302 live_was_active = True 1303 if live and live.is_started: 1304 live_was_active = True 1305 live.stop() 1306 1307 result = await self._handle_tool_call_async( 1308 function_name=call_name, 1309 function_arguments=call_args, 1310 tool_call_id=call_id, 1311 quiet=quiet, 1312 safe=False, 1313 output_callback=output_callback 1314 ) 1315 1316 # Restart live display if it was active before 1317 if live_was_active and live: 1318 live.start() 1319 1320 # Add tool result message 1321 tool_result_info = { 1322 "id": call_id, 1323 "result": result 1324 } 1325 1326 self.messages_add( 1327 role="tool", 1328 content=result, 1329 tool_info=tool_result_info 1330 ) 1331 else: 1332 # Simple assistant response without tool calls 1333 self.messages_add(role="assistant", content=content) 1334 break # No more tools to process, we're done 1335 1336 # Reset live display if needed 1337 if stream and live: 1338 live.stop() 1339 live = Live(console=console, refresh_per_second=100) 1340 live.start() 1341 1342 except Exception as e: 1343 self.logger.error(f"[{self.sdk.upper()} ERROR] {str(e)}") 1344 self._log(f"[ERROR] Error in interaction loop: {str(e)}", level="error") 1345 if live: 1346 live.stop() 1347 return f"Error: {str(e)}" 1348 1349 # Clean up display 1350 if live: 1351 live.stop() 1352 1353 return full_content or None 1354 1355 1356 async def _openai_runner( 1357 self, 1358 *, 1359 model, 1360 messages, 1361 stream, 1362 markdown=False, 1363 quiet=False, 1364 live=None, 1365 output_callback=None 1366 ): 1367 """Handle OpenAI-specific API interactions and response processing.""" 1368 # Log what we're sending for debugging 1369 self._log(f"[OPENAI REQUEST] Sending request to {model} with {len(self.history)} messages", level="debug") 1370 1371 # Prepare API parameters - history is already normalized by _normalizer 1372 params = { 1373 "model": model, 1374 "messages": self.history, 1375 "stream": stream, 1376 } 1377 1378 # Add tools if enabled 1379 if self.tools_enabled and self.tools_supported: 1380 enabled_tools = self._get_enabled_tools() 1381 if enabled_tools: 1382 params["tools"] = enabled_tools 1383 params["tool_choice"] = "auto" 1384 1385 # Call API with retry handling 1386 try: 1387 response = await self._retry_with_backoff( 1388 self.async_client.chat.completions.create, 1389 **params 1390 ) 1391 except Exception: 1392 self.logger.error(f"[OPENAI ERROR RUNNER]: {traceback.format_exc()}") 1393 raise 1394 1395 assistant_content = "" 1396 tool_calls_dict = {} 1397 1398 # Process streaming response 1399 if stream and hasattr(response, "__aiter__"): 1400 async for chunk in response: 1401 delta = getattr(chunk.choices[0], "delta", None) 1402 1403 # Handle content chunks 1404 if hasattr(delta, "content") and delta.content is not None: 1405 text = delta.content 1406 assistant_content += text 1407 if output_callback: 1408 output_callback(text) 1409 elif live: 1410 live.update(Markdown(assistant_content)) 1411 elif not markdown: 1412 print(text, end="") 1413 1414 # Process tool calls 1415 if hasattr(delta, "tool_calls") and delta.tool_calls: 1416 for tool_call_delta in delta.tool_calls: 1417 index = tool_call_delta.index 1418 if index not in tool_calls_dict: 1419 tool_calls_dict[index] = { 1420 "id": tool_call_delta.id if hasattr(tool_call_delta, "id") else None, 1421 "function": {"name": "", "arguments": ""} 1422 } 1423 1424 function = getattr(tool_call_delta, "function", None) 1425 if function: 1426 name = getattr(function, "name", None) 1427 args = getattr(function, "arguments", "") 1428 if name: 1429 tool_calls_dict[index]["function"]["name"] = name 1430 if args: 1431 tool_calls_dict[index]["function"]["arguments"] += args 1432 if tool_call_delta.id and not tool_calls_dict[index]["id"]: 1433 tool_calls_dict[index]["id"] = tool_call_delta.id 1434 1435 # Make sure the ID is set regardless 1436 if hasattr(tool_call_delta, "id") and tool_call_delta.id and not tool_calls_dict[index]["id"]: 1437 tool_calls_dict[index]["id"] = tool_call_delta.id 1438 1439 if not output_callback and not markdown and not quiet: 1440 print() 1441 1442 # Process non-streaming response 1443 else: 1444 message = response.choices[0].message 1445 assistant_content = message.content or "" 1446 1447 if hasattr(message, "tool_calls") and message.tool_calls: 1448 for i, tool_call in enumerate(message.tool_calls): 1449 tool_calls_dict[i] = { 1450 "id": tool_call.id, 1451 "function": { 1452 "name": tool_call.function.name, 1453 "arguments": tool_call.function.arguments 1454 } 1455 } 1456 1457 if output_callback: 1458 output_callback(assistant_content) 1459 elif not quiet: 1460 print(assistant_content) 1461 1462 # Log tool calls for debugging 1463 if tool_calls_dict: 1464 self._log(f"[OPENAI TOOL CALLS] Found {len(tool_calls_dict)} tool calls", level="debug") 1465 for idx, call in tool_calls_dict.items(): 1466 self._log(f"[OPENAI TOOL CALL {idx}] {call['function']['name']} with ID {call['id']}", level="debug") 1467 1468 # Return standardized response format 1469 return { 1470 "content": assistant_content, 1471 "tool_calls": list(tool_calls_dict.values()) 1472 } 1473 1474 1475 async def _anthropic_runner( 1476 self, 1477 *, 1478 model, 1479 messages, 1480 stream, 1481 markdown=False, 1482 quiet=False, 1483 live=None, 1484 output_callback=None 1485 ): 1486 """Handle Anthropic-specific API interactions and response processing.""" 1487 # Log what we're sending for debugging 1488 self._log(f"[ANTHROPIC REQUEST] Sending request to {model} with {len(self.history)} messages", level="debug") 1489 1490 # Prepare API parameters - history is already normalized by _normalizer 1491 params = { 1492 "model": model, 1493 "messages": self.history, 1494 "max_tokens": 8192, 1495 "system": self.system 1496 } 1497 1498 # Add tools support if needed 1499 if self.tools_enabled and self.tools_supported: 1500 enabled_tools = [] 1501 for tool in self._get_enabled_tools(): 1502 # Extract parameters from OpenAI format 1503 tool_params = tool["function"]["parameters"] 1504 1505 # Create Anthropic-compatible tool definition 1506 format_tool = { 1507 "name": tool["function"]["name"], 1508 "description": tool["function"].get("description", ""), 1509 "input_schema": { 1510 "type": "object", 1511 "properties": tool_params.get("properties", {}) 1512 } 1513 } 1514 1515 # Ensure 'required' is at the correct level for Anthropic (as a direct child of input_schema) 1516 if "required" in tool_params: 1517 format_tool["input_schema"]["required"] = tool_params["required"] 1518 1519 enabled_tools.append(format_tool) 1520 1521 params["tools"] = enabled_tools 1522 1523 assistant_content = "" 1524 tool_calls_dict = {} 1525 1526 try: 1527 # Process streaming response 1528 if stream: 1529 stream_params = params.copy() 1530 stream_params["stream"] = True 1531 1532 stream_response = await self._retry_with_backoff( 1533 self.async_client.messages.create, 1534 **stream_params 1535 ) 1536 1537 content_type = None 1538 async for chunk in stream_response: 1539 chunk_type = getattr(chunk, "type", "unknown") 1540 self._log(f"[ANTHROPIC CHUNK] Type: {chunk_type}", level="debug") 1541 if chunk_type == "content_block_start" and hasattr(chunk.content_block, "type"): 1542 content_type = chunk.content_block.type 1543 if content_type == "tool_use": 1544 tool_id = chunk.content_block.id 1545 tool_name = chunk.content_block.name 1546 tool_input = chunk.content_block.input 1547 tool_calls_dict[tool_id] = { 1548 "id": tool_id, 1549 "function": { 1550 "name": tool_name, 1551 "arguments": "" 1552 } 1553 } 1554 self._log(f"[ANTHROPIC TOOL USE] {tool_name}", level="debug") 1555 1556 # Handle text content 1557 if chunk_type == "content_block_delta" and hasattr(chunk.delta, "text"): 1558 delta = chunk.delta.text 1559 assistant_content += delta 1560 if output_callback: 1561 output_callback(delta) 1562 elif live: 1563 live.update(Markdown(assistant_content)) 1564 elif not markdown: 1565 print(delta, end="") 1566 1567 # Handle complete tool use 1568 elif chunk_type == "content_block_delta" and content_type == "tool_use": 1569 tool_calls_dict[tool_id]["function"]["arguments"] += chunk.delta.partial_json 1570 1571 # Process non-streaming response 1572 else: 1573 # For non-streaming, ensure we don't send the stream parameter 1574 non_stream_params = params.copy() 1575 non_stream_params.pop("stream", None) # Remove stream if it exists 1576 1577 response = await self._retry_with_backoff( 1578 self.async_client.messages.create, 1579 **non_stream_params 1580 ) 1581 1582 # Extract text content 1583 for content_block in response.content: 1584 if content_block.type == "text": 1585 assistant_content += content_block.text 1586 1587 if content_block.type == "tool_use": 1588 tool_id = content_block.id 1589 tool_name = content_block.name 1590 tool_input = content_block.input 1591 tool_calls_dict[tool_id] = { 1592 "id": tool_id, 1593 "function": { 1594 "name": tool_name, 1595 "arguments": tool_input 1596 } 1597 } 1598 self._log(f"[ANTHROPIC TOOL USE] {tool_name}", level="debug") 1599 1600 if output_callback: 1601 output_callback(assistant_content) 1602 elif not quiet: 1603 print(assistant_content) 1604 1605 except Exception as e: 1606 self._log(f"[ANTHROPIC ERROR RUNNER] {traceback.format_exc()}", level="error") 1607 1608 # Return something usable even in case of error 1609 return { 1610 "content": f"Error processing Anthropic response: {str(e)}", 1611 "tool_calls": [] 1612 } 1613 1614 # Return standardized response format 1615 return { 1616 "content": assistant_content, 1617 "tool_calls": list(tool_calls_dict.values()) 1618 } 1619 1620 1621 def _get_enabled_tools(self) -> List[dict]: 1622 """Return the list of currently enabled tool function definitions.""" 1623 return [ 1624 tool for tool in self.tools 1625 if not tool["function"].get("disabled", False) 1626 ] 1627 1628 1629 async def _handle_tool_call_async( 1630 self, 1631 function_name: str, 1632 function_arguments: str, 1633 tool_call_id: str, 1634 quiet: bool = False, 1635 safe: bool = False, 1636 output_callback: Optional[Callable[[str], None]] = None 1637 ) -> str: 1638 """Process a tool call asynchronously and return the result. 1639 1640 Args: 1641 function_name: Name of the function to call. 1642 function_arguments: JSON string containing the function arguments. 1643 tool_call_id: Unique identifier for this tool call. 1644 params: Parameters used for the original API call. 1645 safe: If True, prompts for confirmation before executing the tool call. 1646 output_callback: Optional callback to handle the tool call result. 1647 1648 Returns: 1649 The result of the function call. 1650 1651 Raises: 1652 ValueError: If the function is not found or JSON is invalid. 1653 """ 1654 if isinstance(function_arguments, str): 1655 arguments = json.loads(function_arguments) 1656 else: 1657 arguments = function_arguments 1658 1659 self._log(f"[TOOL:{function_name}] args={arguments}") 1660 1661 func = getattr(self, function_name, None) 1662 if not func: 1663 raise ValueError(f"Function '{function_name}' not found.") 1664 1665 be_quiet = self.quiet if quiet is None else quiet 1666 1667 if not be_quiet: 1668 print(f"\nRunning {function_name}...") 1669 1670 if output_callback: 1671 notification = json.dumps({ 1672 "type": "tool_call", 1673 "tool_name": function_name, 1674 "status": "started" 1675 }) 1676 output_callback(notification) 1677 1678 try: 1679 if safe: 1680 prompt = f"[bold yellow]Proposed tool call:[/bold yellow] {function_name}({json.dumps(arguments, indent=2)})\n[bold cyan]Execute? [y/n]: [/bold cyan]" 1681 confirmed = Confirm.ask(prompt, default=False) 1682 if not confirmed: 1683 command_result = { 1684 "status": "cancelled", 1685 "message": "Tool call aborted by user" 1686 } 1687 print("[red]Tool call cancelled by user[/red]") 1688 else: 1689 loop = asyncio.get_event_loop() 1690 command_result = await loop.run_in_executor(None, lambda: func(**arguments)) 1691 else: 1692 loop = asyncio.get_event_loop() 1693 command_result = await loop.run_in_executor(None, lambda: func(**arguments)) 1694 1695 try: 1696 json.dumps(command_result) 1697 except TypeError as e: 1698 self.logger.error(f"Tool call result not serializable: {e}") 1699 return {"error": "Tool call returned unserializable data."} 1700 1701 if output_callback: 1702 notification = json.dumps({ 1703 "type": "tool_call", 1704 "tool_name": function_name, 1705 "status": "completed" 1706 }) 1707 output_callback(notification) 1708 1709 return command_result 1710 1711 except Exception as e: 1712 self._log(f"[ERROR] Tool execution failed: {e}", level="error") 1713 self.logger.error(f"Error executing tool function '{function_name}': {e}") 1714 return {"error": str(e)} 1715 1716 1717 def _setup_encoding(self): 1718 """Initialize the token encoding system for the current model. 1719 1720 This method sets up the appropriate tokenizer based on the current 1721 model provider. For OpenAI models, it attempts to use the model-specific 1722 tokenizer, falling back to cl100k_base if not available. For other 1723 providers, it uses cl100k_base as a default. 1724 1725 The encoding is used for token counting and context management. 1726 """ 1727 try: 1728 if self.provider == "openai": 1729 try: 1730 self.encoding = tiktoken.encoding_for_model(self.model) 1731 self._log(f"[ENCODING] Loaded tokenizer for OpenAI model: {self.model}") 1732 except: 1733 self.encoding = tiktoken.get_encoding("cl100k_base") 1734 self._log(f"[ENCODING] Fallback to cl100k_base for model: {self.model}") 1735 else: 1736 self.encoding = tiktoken.get_encoding("cl100k_base") 1737 self._log(f"[ENCODING] Defaulting to cl100k_base for non-OpenAI model: {self.model}") 1738 except Exception as e: 1739 self.logger.error(f"Failed to setup encoding: {e}") 1740 self.encoding = tiktoken.get_encoding("cl100k_base") 1741 1742 1743 def _estimate_tokens_tiktoken(self, messages) -> int: 1744 """Rough token count estimate using tiktoken for OpenAI or fallback cases.""" 1745 if not hasattr(self, "encoding") or not self.encoding: 1746 self._setup_encoding() 1747 return sum(len(self.encoding.encode(msg.get("content", ""))) for msg in messages if isinstance(msg.get("content"), str)) 1748 1749 1750 def _count_tokens(self, messages, use_cache=True) -> int: 1751 """Accurately estimate token count for messages including tool calls with caching support. 1752 1753 Args: 1754 messages: List of message objects in either OpenAI or Anthropic format. 1755 use_cache: Whether to use and update the token count cache. 1756 1757 Returns: 1758 int: Estimated token count. 1759 """ 1760 # Setup encoding if needed 1761 if not hasattr(self, "encoding") or not self.encoding: 1762 self._setup_encoding() 1763 1764 # Initialize cache if it doesn't exist 1765 if not hasattr(self, "_token_count_cache"): 1766 self._token_count_cache = {} 1767 1768 # Generate a cache key based on message content hashes 1769 if use_cache: 1770 try: 1771 # Create a cache key using message IDs if available, or content hashes 1772 cache_key_parts = [] 1773 for msg in messages: 1774 if isinstance(msg, dict): 1775 # Try to use stable identifiers for cache key 1776 msg_id = msg.get("id", None) 1777 timestamp = msg.get("timestamp", None) 1778 1779 if msg_id and timestamp: 1780 cache_key_parts.append(f"{msg_id}:{timestamp}") 1781 else: 1782 # Fall back to content-based hash if no stable IDs 1783 content_str = str(msg.get("content", "")) 1784 role = msg.get("role", "unknown") 1785 cache_key_parts.append(f"{role}:{hash(content_str)}") 1786 1787 cache_key = ":".join(cache_key_parts) 1788 if cache_key in self._token_count_cache: 1789 return self._token_count_cache[cache_key] 1790 except Exception as e: 1791 # If caching fails, just continue with normal counting 1792 self._log(f"[TOKEN COUNT] Cache key generation failed: {e}", level="debug") 1793 use_cache = False 1794 1795 # For Claude models, try to use their built-in token counter 1796 if self.sdk == "anthropic": 1797 try: 1798 # Convert messages to Anthropic format if needed 1799 anthropic_messages = [] 1800 for msg in messages: 1801 if msg.get("role") == "system": 1802 continue # System handled separately 1803 1804 if msg.get("role") == "tool": 1805 # Skip tool messages in token count to avoid double-counting 1806 continue 1807 1808 if msg.get("role") == "user" and isinstance(msg.get("content"), list): 1809 # Already in Anthropic format with tool_result 1810 anthropic_messages.append(msg) 1811 elif msg.get("role") in ["user", "assistant"]: 1812 if not msg.get("tool_calls") and not msg.get("tool_use"): 1813 # Simple message 1814 anthropic_messages.append({ 1815 "role": msg.get("role"), 1816 "content": msg.get("content", "") 1817 }) 1818 1819 # Use Anthropic's token counter if messages exist 1820 if anthropic_messages: 1821 response = self.client.messages.count_tokens( 1822 model=self.model, 1823 messages=anthropic_messages, 1824 system=self.system 1825 ) 1826 token_count = response.input_tokens 1827 1828 # Cache the result for future use 1829 if use_cache and 'cache_key' in locals(): 1830 self._token_count_cache[cache_key] = token_count 1831 1832 return token_count 1833 except Exception as e: 1834 # Fall back to our estimation 1835 self._log(f"[TOKEN COUNT] Error using Anthropic token counter: {e}", level="debug") 1836 1837 # More accurate token counting for all message types 1838 num_tokens = 0 1839 1840 # Count tokens for each message 1841 for msg in messages: 1842 # Base token count for message metadata (role + message format) 1843 num_tokens += 4 # Message overhead 1844 1845 # Add tokens for role name 1846 role = msg.get("role", "") 1847 num_tokens += len(self.encoding.encode(role)) 1848 1849 # Count tokens in message content 1850 if isinstance(msg.get("content"), str): 1851 content = msg.get("content", "") 1852 content_tokens = len(self.encoding.encode(content)) 1853 num_tokens += content_tokens 1854 1855 elif isinstance(msg.get("content"), list): 1856 # Handle Anthropic-style content lists 1857 for item in msg.get("content", []): 1858 if isinstance(item, dict): 1859 # Tool result or other structured content 1860 if item.get("type") == "tool_result": 1861 result_content = item.get("content", "") 1862 if isinstance(result_content, str): 1863 num_tokens += len(self.encoding.encode(result_content)) 1864 else: 1865 # JSON serialization for dict/list content 1866 num_tokens += len(self.encoding.encode(json.dumps(result_content))) 1867 # Add tokens for tool_use_id and type fields 1868 num_tokens += len(self.encoding.encode(item.get("type", ""))) 1869 num_tokens += len(self.encoding.encode(item.get("tool_use_id", ""))) 1870 1871 # Text content type 1872 elif item.get("type") == "text": 1873 num_tokens += len(self.encoding.encode(item.get("text", ""))) 1874 1875 # Tool use type 1876 elif item.get("type") == "tool_use": 1877 num_tokens += len(self.encoding.encode(item.get("name", ""))) 1878 tool_input = item.get("input", {}) 1879 if isinstance(tool_input, str): 1880 num_tokens += len(self.encoding.encode(tool_input)) 1881 else: 1882 num_tokens += len(self.encoding.encode(json.dumps(tool_input))) 1883 num_tokens += len(self.encoding.encode(item.get("id", ""))) 1884 else: 1885 # Plain text content 1886 num_tokens += len(self.encoding.encode(str(item))) 1887 1888 # Count tokens in tool calls for OpenAI format 1889 if msg.get("tool_calls"): 1890 for tool_call in msg.get("tool_calls", []): 1891 if isinstance(tool_call, dict): 1892 # Count tokens for function name 1893 func_name = tool_call.get("function", {}).get("name", "") 1894 num_tokens += len(self.encoding.encode(func_name)) 1895 1896 # Count tokens for arguments 1897 args = tool_call.get("function", {}).get("arguments", "") 1898 if isinstance(args, str): 1899 num_tokens += len(self.encoding.encode(args)) 1900 else: 1901 num_tokens += len(self.encoding.encode(json.dumps(args))) 1902 1903 # Add tokens for id and type fields 1904 num_tokens += len(self.encoding.encode(tool_call.get("id", ""))) 1905 num_tokens += len(self.encoding.encode(tool_call.get("type", "function"))) 1906 1907 # Count tokens in Anthropic tool_use field 1908 if msg.get("tool_use"): 1909 tool_use = msg.get("tool_use") 1910 # Count tokens for name 1911 num_tokens += len(self.encoding.encode(tool_use.get("name", ""))) 1912 1913 # Count tokens for input 1914 tool_input = tool_use.get("input", {}) 1915 if isinstance(tool_input, str): 1916 num_tokens += len(self.encoding.encode(tool_input)) 1917 else: 1918 num_tokens += len(self.encoding.encode(json.dumps(tool_input))) 1919 1920 # Add tokens for id field 1921 num_tokens += len(self.encoding.encode(tool_use.get("id", ""))) 1922 1923 # Handle tool response message format 1924 if msg.get("role") == "tool": 1925 # Add tokens for tool_call_id 1926 tool_id = msg.get("tool_call_id", "") 1927 num_tokens += len(self.encoding.encode(tool_id)) 1928 1929 # Add message end tokens 1930 num_tokens += 2 1931 1932 # Cache the result for future use 1933 if use_cache and 'cache_key' in locals(): 1934 self._token_count_cache[cache_key] = num_tokens 1935 1936 return num_tokens 1937 1938 1939 def _cycle_messages(self): 1940 """Intelligently trim the message history to fit within the allowed context length. 1941 1942 This method implements a sophisticated trimming strategy that: 1943 1. Always preserves system messages 1944 2. Always keeps the most recent complete conversation turn 1945 3. Prioritizes keeping tool call chains intact 1946 4. Preserves important context from earlier exchanges 1947 5. Aggressively prunes redundant information before essential content 1948 1949 The method uses a multi-pass approach: 1950 1. First pass: Identify critical messages that must be kept 1951 2. Second pass: Calculate token counts for each message 1952 3. Third pass: Keep important tool chains intact 1953 4. Fourth pass: Fill remaining space with recent messages 1954 1955 Returns: 1956 bool: True if all messages were trimmed (context exceeded), 1957 False if trimming was successful and context is within limits 1958 1959 Note: 1960 This method maintains both session_history and history in sync, 1961 ensuring proper SDK-specific formatting is preserved. 1962 """ 1963 # Check if we need to trim 1964 token_count = self._count_tokens(self.history) 1965 1966 # If we're already under the limit, return early 1967 if token_count <= self.context_length: 1968 return False 1969 1970 self._log(f"[TRIM] Starting message cycling: {token_count} tokens exceeds {self.context_length} limit", level="info") 1971 1972 # We'll need to track tokens as we reconstruct the history 1973 target_tokens = max(self.context_length * 0.8, self.context_length - 1000) # Target 80% or 1000 less than max 1974 1975 # First pass: identify critical messages we must keep 1976 must_keep = [] 1977 tool_chain_groups = {} # Group related tool calls and their results 1978 1979 # Always keep system messages (should be first) 1980 system_indices = [] 1981 for i, msg in enumerate(self.history): 1982 if msg.get("role") == "system": 1983 system_indices.append(i) 1984 must_keep.append(i) 1985 1986 # Identify the most recent complete exchange (user question + assistant response) 1987 latest_exchange = [] 1988 # Start from the end and work backward to find the last complete exchange 1989 for i in range(len(self.history) - 1, -1, -1): 1990 msg = self.history[i] 1991 if msg.get("role") == "assistant" and not latest_exchange: 1992 latest_exchange.append(i) 1993 elif msg.get("role") == "user" and latest_exchange: 1994 latest_exchange.append(i) 1995 break 1996 1997 # Add the latest exchange to must-keep 1998 must_keep.extend(latest_exchange) 1999 2000 # Identify tool chains - track which messages belong to the same tool flow 2001 tool_id_to_chain = {} 2002 for i, msg in enumerate(self.history): 2003 # For assistant messages with tool calls 2004 if msg.get("role") == "assistant" and msg.get("tool_calls"): 2005 for tool_call in msg.get("tool_calls"): 2006 tool_id = tool_call.get("id") 2007 if tool_id: 2008 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2009 2010 # For tool response messages 2011 elif msg.get("role") == "tool" and msg.get("tool_call_id"): 2012 tool_id = msg.get("tool_call_id") 2013 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2014 2015 # For Anthropic format with tool use 2016 elif msg.get("role") == "assistant" and isinstance(msg.get("content"), list): 2017 for block in msg.get("content", []): 2018 if isinstance(block, dict) and block.get("type") == "tool_use": 2019 tool_id = block.get("id") 2020 if tool_id: 2021 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2022 2023 # For Anthropic tool result messages 2024 elif msg.get("role") == "user" and isinstance(msg.get("content"), list): 2025 for block in msg.get("content", []): 2026 if isinstance(block, dict) and block.get("type") == "tool_result": 2027 tool_id = block.get("tool_use_id") 2028 if tool_id: 2029 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 2030 2031 # Group together all indices for each tool chain 2032 for tool_id, indices in tool_id_to_chain.items(): 2033 chain_key = f"tool_{min(indices)}" # Group by the earliest message 2034 if chain_key not in tool_chain_groups: 2035 tool_chain_groups[chain_key] = set() 2036 tool_chain_groups[chain_key].update(indices) 2037 2038 # Second pass: calculate tokens for each message 2039 message_tokens = [] 2040 for i, msg in enumerate(self.history): 2041 # Count tokens for this individual message 2042 tokens = self._count_tokens([msg]) 2043 message_tokens.append((i, tokens)) 2044 2045 # Keep the messages identified as must-keep 2046 keep_indices = set(must_keep) 2047 2048 # Calculate the tokens we've committed to keeping 2049 keep_tokens = sum(tokens for i, tokens in message_tokens if i in keep_indices) 2050 2051 # Check if we've already exceeded the target with just must-keep messages 2052 if keep_tokens > self.context_length: 2053 # We're in trouble - the essential messages alone exceed context 2054 # Drop older messages until we're under the limit 2055 all_indices = sorted(keep_indices) 2056 2057 # Start dropping oldest messages, but NEVER drop system messages 2058 for idx in all_indices: 2059 if idx not in system_indices: 2060 keep_indices.remove(idx) 2061 keep_tokens -= message_tokens[idx][1] 2062 if keep_tokens <= target_tokens: 2063 break 2064 2065 # If we've removed everything but system messages and still over limit 2066 if keep_tokens > self.context_length: 2067 self._log(f"[TRIM] Critical failure: even with minimal context ({keep_tokens} tokens), we exceed the limit", level="error") 2068 # Keep only system messages if any 2069 keep_indices = set(system_indices) 2070 return True # Context exceeded completely 2071 2072 # Third pass: keep the most important tool chains intact 2073 available_tokens = target_tokens - keep_tokens 2074 # Sort tool chains by recency (assumed by the chain_key which uses the earliest message) 2075 sorted_chains = sorted(tool_chain_groups.items(), key=lambda x: x[0], reverse=True) 2076 2077 for chain_key, indices in sorted_chains: 2078 # Skip if we've already decided to keep all messages in this chain 2079 if indices.issubset(keep_indices): 2080 continue 2081 2082 # Calculate how many tokens this chain would add 2083 chain_tokens = sum(tokens for i, tokens in message_tokens if i in indices and i not in keep_indices) 2084 2085 # If we can fit the entire chain, keep it 2086 if chain_tokens <= available_tokens: 2087 keep_indices.update(indices) 2088 available_tokens -= chain_tokens 2089 # Otherwise, we might want to keep partial chains in the future, but for now, skip 2090 2091 # Fourth pass: fill in with as many remaining messages as possible, prioritizing recency 2092 # Get remaining messages sorted by recency (newest first) 2093 remaining_indices = [(i, tokens) for i, tokens in message_tokens if i not in keep_indices] 2094 remaining_indices.sort(reverse=True) # Sort newest first 2095 2096 for i, tokens in remaining_indices: 2097 if tokens <= available_tokens: 2098 keep_indices.add(i) 2099 available_tokens -= tokens 2100 2101 # Final message reconstruction 2102 self._log(f"[TRIM] Keeping {len(keep_indices)}/{len(self.history)} messages, estimated {target_tokens - available_tokens} tokens", level="info") 2103 2104 # Create new history with just the kept messages, preserving order 2105 new_history = [self.history[i] for i in sorted(keep_indices)] 2106 self.history = new_history 2107 2108 # Update session_history to match the pruned history 2109 if hasattr(self, "session_history"): 2110 # Map between history items and session_history 2111 session_to_keep = [] 2112 2113 # For each session history message, check if it corresponds to a kept message 2114 for session_msg in self.session_history: 2115 # Keep system messages 2116 if session_msg.get("role") == "system": 2117 session_to_keep.append(session_msg) 2118 continue 2119 2120 # Try to match based on available IDs or content 2121 msg_id = session_msg.get("id") 2122 2123 # For tool messages, check tool_info.id against tool_call_id 2124 if "metadata" in session_msg and "tool_info" in session_msg["metadata"]: 2125 tool_id = session_msg["metadata"]["tool_info"].get("id") 2126 2127 # Check if this tool_id is still in the kept history 2128 for history_msg in new_history: 2129 # Check standard ids 2130 history_tool_id = None 2131 2132 # Check OpenAI format 2133 if history_msg.get("role") == "tool": 2134 history_tool_id = history_msg.get("tool_call_id") 2135 elif history_msg.get("role") == "assistant" and history_msg.get("tool_calls"): 2136 for call in history_msg.get("tool_calls", []): 2137 if call.get("id") == tool_id: 2138 history_tool_id = call.get("id") 2139 break 2140 2141 # Check Anthropic format 2142 elif isinstance(history_msg.get("content"), list): 2143 for block in history_msg.get("content", []): 2144 if isinstance(block, dict): 2145 if block.get("type") == "tool_use" and block.get("id") == tool_id: 2146 history_tool_id = block.get("id") 2147 break 2148 elif block.get("type") == "tool_result" and block.get("tool_use_id") == tool_id: 2149 history_tool_id = block.get("tool_use_id") 2150 break 2151 2152 if history_tool_id == tool_id: 2153 session_to_keep.append(session_msg) 2154 break 2155 2156 # For regular messages, try content matching as fallback 2157 else: 2158 content_match = False 2159 if isinstance(session_msg.get("content"), str) and session_msg.get("content"): 2160 for history_msg in new_history: 2161 if history_msg.get("role") == session_msg.get("role") and history_msg.get("content") == session_msg.get("content"): 2162 content_match = True 2163 break 2164 2165 if content_match: 2166 session_to_keep.append(session_msg) 2167 2168 # Update session_history with kept messages 2169 self.session_history = session_to_keep 2170 2171 # Re-normalize to ensure consistency 2172 self._normalizer(force=True) 2173 2174 # Verify our final token count 2175 final_token_count = self._count_tokens(self.history) 2176 self._log(f"[TRIM] Final history has {len(self.history)} messages, {final_token_count} tokens", level="info") 2177 2178 # Return whether we've completely exceeded context 2179 return final_token_count > self.context_length or len(self.history) == len(system_indices) 2180 2181 2182 def messages_add( 2183 self, 2184 role: str, 2185 content: Any, 2186 tool_info: Optional[Dict] = None, 2187 normalize: bool = True 2188 ) -> str: 2189 """ 2190 Add a message to the standardized session_history and then update SDK-specific history. 2191 2192 This method is the central point for all message additions to the conversation. 2193 2194 Args: 2195 role: The role of the message ("user", "assistant", "system", "tool") 2196 content: The message content (text or structured) 2197 tool_info: Optional tool-related metadata 2198 normalize: Whether to normalize history after adding this message 2199 2200 Returns: 2201 str: Unique ID of the added message 2202 """ 2203 # Generate a unique message ID 2204 message_id = str(uuid.uuid4()) 2205 2206 # Create the standardized message for session_history 2207 timestamp = datetime.now(timezone.utc).isoformat() 2208 2209 # Store system messages directly 2210 if role == "system": 2211 self.system = content 2212 2213 # Create standard format message 2214 standard_message = { 2215 "role": role, 2216 "content": content, 2217 "id": message_id, 2218 "timestamp": timestamp, 2219 "metadata": { 2220 "sdk": self.sdk 2221 } 2222 } 2223 2224 # Add tool info if provided 2225 if tool_info: 2226 standard_message["metadata"]["tool_info"] = tool_info 2227 2228 # Add to session_history 2229 if not hasattr(self, "session_history"): 2230 self.session_history = [] 2231 2232 self.session_history.append(standard_message) 2233 2234 # Save to persistent session if enabled 2235 if self.session_enabled and self.session_id: 2236 # Convert standard message to session-compatible format 2237 session_msg = { 2238 "role": role, 2239 "content": content, 2240 "id": message_id, 2241 "timestamp": timestamp 2242 } 2243 2244 # Add tool-related fields if present 2245 if tool_info: 2246 for key, value in tool_info.items(): 2247 session_msg[key] = value 2248 2249 # Store in session 2250 self.session.msg_insert(self.session_id, session_msg) 2251 2252 # Update the SDK-specific format in self.history by running the normalizer 2253 if normalize: 2254 # We only need to normalize the most recent message for efficiency 2255 # Pass a flag indicating we're just normalizing a new message 2256 self._normalizer(force=False, new_message_only=True) 2257 2258 # Log the added message 2259 self._log(f"[MESSAGE ADDED] {role}: {str(content)[:50]}...") 2260 2261 return message_id 2262 2263 2264 def messages_system(self, prompt: str): 2265 """Set or retrieve the current system prompt. 2266 2267 This method manages the system prompt that guides the AI's behavior. 2268 It can be used to both set a new system prompt and retrieve the current one. 2269 When setting a new prompt, it updates the system message in the conversation 2270 history and persists it to the session if enabled. 2271 2272 Args: 2273 prompt (str): The new system prompt to set. If empty or None, 2274 returns the current system prompt without changes. 2275 2276 Returns: 2277 str: The current system prompt after any updates 2278 """ 2279 if not isinstance(prompt, str) or not prompt: 2280 return self.system 2281 2282 # If the prompt hasn't changed, don't do anything 2283 if self.system == prompt: 2284 return self.system 2285 2286 # Update the system prompt 2287 old_system = self.system 2288 self.system = prompt 2289 2290 # For OpenAI, update or insert the system message in history 2291 if self.sdk == "openai": 2292 # Check if there's already a system message 2293 system_index = next((i for i, msg in enumerate(self.history) 2294 if msg.get("role") == "system"), None) 2295 2296 if system_index is not None: 2297 # Update existing system message 2298 self.history[system_index]["content"] = prompt 2299 else: 2300 # Insert new system message at the beginning 2301 self.history.insert(0, {"role": "system", "content": prompt}) 2302 2303 # For Anthropic, system message is not part of history, just save it for API calls 2304 2305 # Log to session only if prompt actually changed 2306 if self.session_enabled and self.session_id and old_system != prompt: 2307 self.session.msg_insert(self.session_id, {"role": "system", "content": prompt}) 2308 2309 return self.system 2310 2311 2312 def messages(self) -> list: 2313 """Return full session messages (persisted or in-memory).""" 2314 if self.session_enabled and self.session_id: 2315 return self.session.load_full(self.session_id).get("messages", []) 2316 return self.session_history 2317 2318 2319 def messages_length(self) -> int: 2320 """Calculate the total token count for the message history.""" 2321 if not self.encoding: 2322 return 0 2323 2324 total_tokens = 0 2325 for message in self.history: 2326 if message.get("content"): 2327 total_tokens += len(self.encoding.encode(message["content"])) 2328 if message.get("tool_calls"): 2329 for tool_call in message["tool_calls"]: 2330 if tool_call.get("function"): 2331 total_tokens += len(self.encoding.encode(tool_call["function"].get("name", ""))) 2332 total_tokens += len(self.encoding.encode(tool_call["function"].get("arguments", ""))) 2333 return total_tokens 2334 2335 2336 def session_load(self, session_id: Optional[str]): 2337 """Load and normalize messages for a specific session. 2338 2339 This method loads a conversation session from persistent storage and 2340 normalizes the messages to the current SDK format. It handles system 2341 messages, tool calls, and maintains message ordering. If loading fails, 2342 it resets to an empty session with the default system prompt. 2343 2344 Args: 2345 session_id (Optional[str]): The ID of the session to load. 2346 If None, resets to in-memory mode. 2347 2348 Note: 2349 This method will update both session_history and history to match 2350 the loaded session's state. It also ensures proper SDK-specific 2351 message formatting. 2352 """ 2353 self.session_id = session_id 2354 self._last_session_id = session_id 2355 2356 if self.session_enabled and session_id: 2357 try: 2358 # Load raw session data 2359 session_data = self.session.load_full(session_id) 2360 messages = session_data.get("messages", []) 2361 2362 # Convert session format to our standard format 2363 self.session_history = [] 2364 2365 # Track the most recent system message 2366 latest_system_msg = None 2367 2368 for msg in messages: 2369 # Extract fields 2370 role = msg.get("role", "user") 2371 content = msg.get("content", "") 2372 msg_id = msg.get("id", str(uuid.uuid4())) 2373 timestamp = msg.get("timestamp", datetime.now(timezone.utc).isoformat()) 2374 2375 # If this is a system message, track it but don't add to session_history yet 2376 if role == "system": 2377 if latest_system_msg is None or timestamp > latest_system_msg["timestamp"]: 2378 latest_system_msg = { 2379 "role": role, 2380 "content": content, 2381 "id": msg_id, 2382 "timestamp": timestamp, 2383 "metadata": {"sdk": self.sdk} 2384 } 2385 continue 2386 2387 # Build tool_info if present 2388 tool_info = None 2389 if any(key in msg for key in ["tool_use_id", "tool_call_id", "name", "arguments"]): 2390 tool_info = { 2391 "id": msg.get("tool_use_id") or msg.get("tool_call_id"), 2392 "name": msg.get("name", "unknown_tool"), 2393 "arguments": msg.get("arguments", {}) 2394 } 2395 2396 # Create standard message 2397 standard_msg = { 2398 "role": role, 2399 "content": content, 2400 "id": msg_id, 2401 "timestamp": timestamp, 2402 "metadata": { 2403 "sdk": self.sdk 2404 } 2405 } 2406 2407 if tool_info: 2408 standard_msg["metadata"]["tool_info"] = tool_info 2409 2410 self.session_history.append(standard_msg) 2411 2412 # If we found a system message, update the system property and add to history 2413 if latest_system_msg: 2414 self.system = latest_system_msg["content"] 2415 # Insert at the beginning of session_history 2416 self.session_history.insert(0, latest_system_msg) 2417 else: 2418 # If no system message was found, add the current system message 2419 self.messages_add(role="system", content=self.system) 2420 2421 # Normalize to current SDK format 2422 self._normalizer(force=True) 2423 2424 self._log(f"[SESSION] Switched to session '{session_id}'") 2425 except Exception as e: 2426 self.logger.error(f"Failed to load session '{session_id}': {e}") 2427 self.session_reset() 2428 else: 2429 # Reset to empty state with system message 2430 self.session_reset() 2431 2432 2433 def session_reset(self): 2434 """Reset the current session state and reinitialize to default system prompt. 2435 2436 This method performs a complete reset of the conversation state: 2437 1. Clears all message history 2438 2. Disables session ID tracking 2439 3. Returns to in-memory mode 2440 4. Reinitializes with the default system prompt 2441 2442 The reset is useful for starting fresh conversations or recovering 2443 from error states. It maintains the basic system configuration while 2444 clearing all conversation context. 2445 """ 2446 self.session_id = None 2447 self._last_session_id = None 2448 2449 # Clear histories 2450 self.session_history = [] 2451 self.history = [] 2452 2453 # Reapply the system message 2454 if hasattr(self, "system") and self.system: 2455 # Add to session_history 2456 self.messages_add(role="system", content=self.system) 2457 else: 2458 # Ensure we have a default system message 2459 self.system = "You are a helpful Assistant." 2460 self.messages_add(role="system", content=self.system) 2461 2462 self._log("[SESSION] Reset to in-memory mode") 2463 2464 2465 def _normalizer(self, force=False, new_message_only=False): 2466 """Central normalization function for message format conversion. 2467 2468 This method transforms the standardized session_history into the 2469 SDK-specific format needed in self.history. It handles different 2470 message types (system, user, assistant, tool) and their various 2471 formats across different SDKs. 2472 2473 Args: 2474 force (bool): If True, always normalize even if SDK hasn't changed. 2475 Default is False, which only normalizes on SDK change. 2476 new_message_only (bool): If True, only normalize the most recent message 2477 for efficiency when adding single messages. 2478 2479 Note: 2480 This method is the central point for message format conversion and 2481 ensures consistency between session storage and API communication. 2482 """ 2483 # Skip normalization if SDK hasn't changed and force is False 2484 if not force and hasattr(self, '_last_sdk') and self._last_sdk == self.sdk: 2485 # If we only need to normalize the most recent message 2486 if new_message_only and self.session_history: 2487 # Get the most recent message from session_history 2488 recent_msg = self.session_history[-1] 2489 2490 # Apply SDK-specific normalization for just this message 2491 if self.sdk == "openai": 2492 self._openai_normalize_message(recent_msg) 2493 elif self.sdk == "anthropic": 2494 self._anthropic_normalize_message(recent_msg) 2495 else: 2496 # Generic handler for unknown SDKs 2497 self._generic_normalize_message(recent_msg) 2498 2499 return 2500 2501 # Record the current SDK to detect future changes 2502 self._last_sdk = self.sdk 2503 2504 # For full normalization, clear current history and rebuild it 2505 self.history = [] 2506 2507 # Call the appropriate SDK-specific normalizer 2508 if self.sdk == "openai": 2509 self._openai_normalizer() 2510 elif self.sdk == "anthropic": 2511 self._anthropic_normalizer() 2512 else: 2513 self.logger.warning(f"No normalizer available for SDK: {self.sdk}") 2514 # Fallback to a simple conversion for unknown SDKs 2515 for msg in self.session_history: 2516 self._generic_normalize_message(msg) 2517 2518 2519 def _openai_normalizer(self): 2520 """Convert standardized session_history to OpenAI-compatible format. 2521 2522 This method transforms the internal message format into the structure 2523 required by the OpenAI API. It handles: 2524 - System messages at the start of history 2525 - User messages with plain text 2526 - Assistant messages with optional tool calls 2527 - Tool response messages with tool_call_id 2528 2529 The resulting format matches OpenAI's chat completion API requirements 2530 for both regular messages and function calling. 2531 """ 2532 # For OpenAI, we need to include system message in the history 2533 # and convert tool calls/results to OpenAI format 2534 2535 # Start with empty history 2536 self.history = [] 2537 2538 # First, add the current system message at position 0 2539 self.history.append({ 2540 "role": "system", 2541 "content": self.system 2542 }) 2543 2544 # Process all non-system messages 2545 for msg in self.session_history: 2546 if msg["role"] == "system": 2547 continue # Skip system messages, already handled 2548 2549 # Handle different message types 2550 if msg["role"] == "user": 2551 # User messages are straightforward 2552 self.history.append({ 2553 "role": "user", 2554 "content": msg["content"] 2555 }) 2556 2557 elif msg["role"] == "assistant": 2558 # For assistant messages with tool calls 2559 if "metadata" in msg and msg["metadata"].get("tool_info"): 2560 # This is an assistant message with tool calls 2561 tool_info = msg["metadata"]["tool_info"] 2562 2563 # Create OpenAI assistant message with tool calls 2564 assistant_msg = { 2565 "role": "assistant", 2566 "content": msg["content"] if isinstance(msg["content"], str) else "", 2567 "tool_calls": [{ 2568 "id": tool_info["id"], 2569 "type": "function", 2570 "function": { 2571 "name": tool_info["name"], 2572 "arguments": json.dumps(tool_info["arguments"]) if isinstance(tool_info["arguments"], dict) else tool_info["arguments"] 2573 } 2574 }] 2575 } 2576 self.history.append(assistant_msg) 2577 else: 2578 # Regular assistant message 2579 self.history.append({ 2580 "role": "assistant", 2581 "content": msg["content"] 2582 }) 2583 2584 elif msg["role"] == "tool": 2585 # Tool response messages 2586 if "metadata" in msg and "tool_info" in msg["metadata"]: 2587 tool_msg = { 2588 "role": "tool", 2589 "tool_call_id": msg["metadata"]["tool_info"]["id"], 2590 "content": json.dumps(msg["content"]) if isinstance(msg["content"], (dict, list)) else msg["content"] 2591 } 2592 self.history.append(tool_msg) 2593 2594 2595 def _anthropic_normalizer(self): 2596 """Convert standardized session_history to Anthropic-compatible format. 2597 2598 This method transforms the internal message format into the structure 2599 required by the Anthropic API. It handles: 2600 - System messages (stored separately, not in history) 2601 - User messages with optional tool results 2602 - Assistant messages with optional tool use 2603 - Content blocks for structured responses 2604 2605 The resulting format matches Anthropic's message API requirements 2606 for both regular messages and tool use. 2607 """ 2608 # For Anthropic, we don't include system message in the history 2609 # but need to handle content blocks for tool use/results 2610 2611 # Start with empty history 2612 self.history = [] 2613 2614 # Process all non-system messages 2615 for msg in self.session_history: 2616 if msg["role"] == "system": 2617 # Update system prompt if this is the most recent system message 2618 # (only apply the most recent system message if we have multiple) 2619 if msg == self.session_history[-1] or all(m["role"] != "system" for m in self.session_history[self.session_history.index(msg)+1:]): 2620 self.system = msg["content"] 2621 continue # Skip system messages in history 2622 2623 # Handle different message types 2624 if msg["role"] == "user": 2625 # User messages - check if it contains tool results 2626 if "metadata" in msg and "tool_info" in msg["metadata"] and msg["metadata"]["tool_info"].get("result"): 2627 # This is a tool result message 2628 tool_info = msg["metadata"]["tool_info"] 2629 2630 # Create Anthropic tool result format 2631 tool_result_msg = { 2632 "role": "user", 2633 "content": [{ 2634 "type": "tool_result", 2635 "tool_use_id": tool_info["id"], 2636 "content": json.dumps(tool_info["result"]) if isinstance(tool_info["result"], (dict, list)) else str(tool_info["result"]) 2637 }] 2638 } 2639 self.history.append(tool_result_msg) 2640 else: 2641 # Regular user message 2642 self.history.append({ 2643 "role": "user", 2644 "content": msg["content"] 2645 }) 2646 2647 elif msg["role"] == "assistant": 2648 # For assistant messages, check for tool use 2649 if "metadata" in msg and "tool_info" in msg["metadata"]: 2650 # This is an assistant message with tool use 2651 tool_info = msg["metadata"]["tool_info"] 2652 2653 # Build content blocks 2654 content_blocks = [] 2655 2656 # Add text content if present 2657 if msg["content"]: 2658 content_blocks.append({ 2659 "type": "text", 2660 "text": msg["content"] if isinstance(msg["content"], str) else "" 2661 }) 2662 2663 # Add tool use block 2664 content_blocks.append({ 2665 "type": "tool_use", 2666 "id": tool_info["id"], 2667 "name": tool_info["name"], 2668 "input": tool_info["arguments"] if isinstance(tool_info["arguments"], dict) else json.loads(tool_info["arguments"]) 2669 }) 2670 2671 # Create Anthropic assistant message with tool use 2672 self.history.append({ 2673 "role": "assistant", 2674 "content": content_blocks 2675 }) 2676 else: 2677 # Regular assistant message 2678 self.history.append({ 2679 "role": "assistant", 2680 "content": msg["content"] 2681 }) 2682 2683 elif msg["role"] == "tool": 2684 # Tool messages in standard format get converted to user messages with tool_result 2685 if "metadata" in msg and "tool_info" in msg["metadata"]: 2686 tool_info = msg["metadata"]["tool_info"] 2687 2688 # Create Anthropic tool result message 2689 tool_result_msg = { 2690 "role": "user", 2691 "content": [{ 2692 "type": "tool_result", 2693 "tool_use_id": tool_info["id"], 2694 "content": json.dumps(msg["content"]) if isinstance(msg["content"], (dict, list)) else str(msg["content"]) 2695 }] 2696 } 2697 self.history.append(tool_result_msg) 2698 2699 2700 def _openai_normalize_message(self, msg): 2701 """Normalize a single message to OpenAI format and add to history.""" 2702 role = msg.get("role") 2703 content = msg.get("content") 2704 2705 if role == "system": 2706 # Check if we already have a system message in history 2707 system_index = next((i for i, m in enumerate(self.history) 2708 if m.get("role") == "system"), None) 2709 if system_index is not None: 2710 # Update existing system message 2711 self.history[system_index]["content"] = content 2712 else: 2713 # Insert new system message at the beginning 2714 self.history.insert(0, { 2715 "role": "system", 2716 "content": content 2717 }) 2718 # Update the system property 2719 self.system = content 2720 2721 elif role == "user": 2722 self.history.append({ 2723 "role": "user", 2724 "content": content 2725 }) 2726 2727 elif role == "assistant": 2728 # For assistant messages, handle potential tool calls 2729 if "metadata" in msg and msg["metadata"].get("tool_info"): 2730 # This is an assistant message with tool calls 2731 tool_info = msg["metadata"]["tool_info"] 2732 2733 # Create OpenAI assistant message with tool calls 2734 try: 2735 arguments = tool_info.get("arguments", {}) 2736 arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else arguments 2737 except: 2738 arguments_str = str(arguments) 2739 2740 assistant_msg = { 2741 "role": "assistant", 2742 "content": content if isinstance(content, str) else "", 2743 "tool_calls": [{ 2744 "id": tool_info["id"], 2745 "type": "function", 2746 "function": { 2747 "name": tool_info["name"], 2748 "arguments": arguments_str 2749 } 2750 }] 2751 } 2752 self.history.append(assistant_msg) 2753 else: 2754 # Regular assistant message 2755 self.history.append({ 2756 "role": "assistant", 2757 "content": content 2758 }) 2759 2760 elif role == "tool": 2761 # Tool response messages 2762 if "metadata" in msg and "tool_info" in msg["metadata"]: 2763 tool_info = msg["metadata"]["tool_info"] 2764 tool_msg = { 2765 "role": "tool", 2766 "tool_call_id": tool_info["id"], 2767 "content": json.dumps(content) if isinstance(content, (dict, list)) else str(content) 2768 } 2769 self.history.append(tool_msg) 2770 2771 2772 def _anthropic_normalize_message(self, msg): 2773 """Normalize a single message to Anthropic format and add to history.""" 2774 role = msg.get("role") 2775 content = msg.get("content") 2776 2777 if role == "system": 2778 # Store system prompt separately, not in history for Anthropic 2779 self.system = content 2780 2781 elif role == "user": 2782 # User messages - check if it contains tool results 2783 if "metadata" in msg and "tool_info" in msg["metadata"]: 2784 tool_info = msg["metadata"]["tool_info"] 2785 # Check for result or directly use content 2786 result_content = tool_info.get("result", content) 2787 2788 # Create Anthropic tool result format 2789 try: 2790 result_str = json.dumps(result_content) if isinstance(result_content, (dict, list)) else str(result_content) 2791 except: 2792 result_str = str(result_content) 2793 2794 tool_result_msg = { 2795 "role": "user", 2796 "content": [{ 2797 "type": "tool_result", 2798 "tool_use_id": tool_info["id"], 2799 "content": result_str 2800 }] 2801 } 2802 self.history.append(tool_result_msg) 2803 else: 2804 # Regular user message 2805 self.history.append({ 2806 "role": "user", 2807 "content": content 2808 }) 2809 2810 elif role == "assistant": 2811 # For assistant messages, check for tool use 2812 if "metadata" in msg and "tool_info" in msg["metadata"]: 2813 # This is an assistant message with tool use 2814 tool_info = msg["metadata"]["tool_info"] 2815 2816 # Build content blocks 2817 content_blocks = [] 2818 2819 # Add text content if present 2820 if content: 2821 content_blocks.append({ 2822 "type": "text", 2823 "text": content if isinstance(content, str) else "" 2824 }) 2825 2826 # Add tool use block - safely convert arguments 2827 try: 2828 # Parse arguments to ensure it's a dictionary 2829 if isinstance(tool_info["arguments"], str): 2830 try: 2831 args_dict = json.loads(tool_info["arguments"]) 2832 except json.JSONDecodeError: 2833 args_dict = {"text": tool_info["arguments"]} 2834 else: 2835 args_dict = tool_info["arguments"] 2836 except: 2837 args_dict = {"error": "Failed to parse arguments"} 2838 2839 content_blocks.append({ 2840 "type": "tool_use", 2841 "id": tool_info["id"], 2842 "name": tool_info["name"], 2843 "input": args_dict 2844 }) 2845 2846 # Create Anthropic assistant message with tool use 2847 self.history.append({ 2848 "role": "assistant", 2849 "content": content_blocks 2850 }) 2851 else: 2852 # Regular assistant message 2853 self.history.append({ 2854 "role": "assistant", 2855 "content": content 2856 }) 2857 2858 elif role == "tool": 2859 # Tool messages in standard format get converted to user messages with tool_result 2860 if "metadata" in msg and "tool_info" in msg["metadata"]: 2861 tool_info = msg["metadata"]["tool_info"] 2862 2863 try: 2864 result_str = json.dumps(content) if isinstance(content, (dict, list)) else str(content) 2865 except: 2866 result_str = str(content) 2867 2868 # Create Anthropic tool result message 2869 tool_result_msg = { 2870 "role": "user", 2871 "content": [{ 2872 "type": "tool_result", 2873 "tool_use_id": tool_info["id"], 2874 "content": result_str 2875 }] 2876 } 2877 self.history.append(tool_result_msg) 2878 2879 2880 def _generic_normalize_message(self, msg): 2881 """Generic normalizer for unknown SDKs. 2882 2883 This method provides a basic message normalization for SDKs that 2884 don't have specific handling. It performs minimal conversion to 2885 ensure basic message structure is maintained. 2886 2887 Args: 2888 msg (dict): The message to normalize, containing at minimum: 2889 - role: Message role (user/assistant/system) 2890 - content: Message content 2891 2892 Note: 2893 This is a fallback method and should be overridden for specific 2894 SDK implementations when possible. 2895 """ 2896 role = msg.get("role") 2897 content = msg.get("content") 2898 2899 if role in ["user", "assistant", "system"]: 2900 self.history.append({ 2901 "role": role, 2902 "content": content 2903 }) 2904 2905 2906 def track_token_usage(self): 2907 """Track and return token usage across the conversation history. 2908 2909 This method maintains a history of token usage measurements and provides 2910 current usage statistics. It tracks: 2911 - Current token count 2912 - Context length limit 2913 - Usage percentage 2914 - Historical measurements 2915 - Current provider and model info 2916 2917 Returns: 2918 dict: Dictionary containing: 2919 - current: Current token count 2920 - limit: Maximum context length 2921 - percentage: Usage as percentage of limit 2922 - history: Last 10 measurements with timestamps 2923 - provider: Current provider name 2924 - model: Current model name 2925 2926 Note: 2927 The history is limited to the last 100 measurements to prevent 2928 unbounded memory growth. 2929 """ 2930 if not hasattr(self, "_token_history"): 2931 self._token_history = [] 2932 2933 # Count current tokens 2934 current_count = self._count_tokens(self.history) 2935 2936 # Add to history 2937 timestamp = datetime.now(timezone.utc).isoformat() 2938 self._token_history.append({ 2939 "timestamp": timestamp, 2940 "count": current_count, 2941 "limit": self.context_length, 2942 "provider": self.provider, 2943 "model": self.model 2944 }) 2945 2946 # Keep only the last 100 measurements to avoid unlimited growth 2947 if len(self._token_history) > 100: 2948 self._token_history = self._token_history[-100:] 2949 2950 # Return current tracking info 2951 return { 2952 "current": current_count, 2953 "limit": self.context_length, 2954 "percentage": round((current_count / self.context_length) * 100, 1) if self.context_length else 0, 2955 "history": self._token_history[-10:], # Return last 10 measurements 2956 "provider": self.provider, 2957 "model": self.model 2958 } 2959 2960 2961 def get_message_token_breakdown(self): 2962 """Analyze token usage by message type and provide a detailed breakdown. 2963 2964 This method performs a detailed analysis of token usage across the 2965 conversation history, breaking down usage by: 2966 - Message role (system, user, assistant, tool) 2967 - Content type (text, tool calls, tool results) 2968 - Individual message statistics 2969 2970 Returns: 2971 dict: Token usage breakdown containing: 2972 - total: Total tokens used 2973 - by_role: Tokens used by each role 2974 - by_type: Tokens used by content type 2975 - messages: List of individual message stats including: 2976 - index: Message position 2977 - role: Message role 2978 - tokens: Tokens used 2979 - has_tools: Whether message contains tool calls 2980 2981 Note: 2982 This analysis is useful for understanding token usage patterns 2983 and optimizing conversation context. 2984 """ 2985 breakdown = { 2986 "total": 0, 2987 "by_role": { 2988 "system": 0, 2989 "user": 0, 2990 "assistant": 0, 2991 "tool": 0 2992 }, 2993 "by_type": { 2994 "text": 0, 2995 "tool_calls": 0, 2996 "tool_results": 0 2997 }, 2998 "messages": [] 2999 } 3000 3001 # Analyze each message 3002 for i, msg in enumerate(self.history): 3003 msg_tokens = self._count_tokens([msg]) 3004 role = msg.get("role", "unknown") 3005 3006 # Track by role 3007 if role in breakdown["by_role"]: 3008 breakdown["by_role"][role] += msg_tokens 3009 3010 # Track by content type 3011 if role == "assistant" and msg.get("tool_calls"): 3012 breakdown["by_type"]["tool_calls"] += msg_tokens 3013 elif role == "tool": 3014 breakdown["by_type"]["tool_results"] += msg_tokens 3015 else: 3016 breakdown["by_type"]["text"] += msg_tokens 3017 3018 # Add individual message data 3019 breakdown["messages"].append({ 3020 "index": i, 3021 "role": role, 3022 "tokens": msg_tokens, 3023 "has_tools": bool(msg.get("tool_calls") or msg.get("tool_use") or 3024 (isinstance(msg.get("content"), list) and 3025 any(isinstance(c, dict) and c.get("type") in ["tool_use", "tool_result"] 3026 for c in msg.get("content", [])))) 3027 }) 3028 3029 # Update tota? 3030 breakdown["total"] += msg_tokens 3031 3032 return breakdown
65 def __init__( 66 self, 67 base_url: Optional[str] = None, 68 api_key: Optional[str] = None, 69 model: str = "openai:gpt-4o-mini", 70 fallback_model = "ollama:mistral-nemo:latest", 71 tools: Optional[bool] = True, 72 stream: bool = True, 73 quiet: bool = False, 74 context_length: int = 128000, 75 max_retries: int = 3, 76 retry_delay: float = 1.0, 77 log_path: Optional[str] = None, 78 raw: Optional[bool] = False, 79 session_enabled: bool = False, 80 session_id: Optional[str] = None, 81 session_path: Optional[str] = None 82 ): 83 """Initialize the universal AI interaction client. 84 85 Args: 86 base_url: Optional base URL for the API. If None, uses the provider's default URL. 87 api_key: Optional API key. If None, attempts to use environment variables based on provider. 88 model: Model identifier in format "provider:model_name" (e.g., "openai:gpt-4o-mini"). 89 tools: Enable (True) or disable (False) tool calling; None for auto-detection based on model support. 90 stream: Enable (True) or disable (False) streaming responses. 91 context_length: Maximum number of tokens to maintain in conversation history. 92 max_retries: Maximum number of retries for failed API calls. 93 retry_delay: Initial delay (in seconds) for exponential backoff retries. 94 session_enabled: Enable persistent session support. 95 session_id: Optional session ID to load messages from. 96 97 Raises: 98 ValueError: If provider is not supported or API key is missing for non-Ollama providers. 99 """ 100 self.system = "You are a helpful Assistant." 101 self.raw = raw 102 self.quiet = quiet 103 self.logger = logging.getLogger(f"InteractorLogger_{id(self)}") 104 self.logger.setLevel(logging.DEBUG) 105 self.providers = { 106 "openai": { 107 "sdk": "openai", 108 "base_url": "https://api.openai.com/v1", 109 "api_key": api_key or os.getenv("OPENAI_API_KEY") or None 110 }, 111 "ollama": { 112 "sdk": "openai", 113 "base_url": "http://localhost:11434/v1", 114 "api_key": api_key or "ollama" 115 }, 116 "nvidia": { 117 "sdk": "openai", 118 "base_url": "https://integrate.api.nvidia.com/v1", 119 "api_key": api_key or os.getenv("NVIDIA_API_KEY") or None 120 }, 121 "google": { 122 "sdk": "openai", 123 "base_url": "https://generativelanguage.googleapis.com/v1beta/openai", 124 "api_key": api_key or os.getenv("GEMINI_API_KEY") or None 125 }, 126 "anthropic": { 127 "sdk": "anthropic", 128 "base_url": "https://api.anthropic.com/v1", 129 "api_key": api_key or os.getenv("ANTHROPIC_API_KEY") or None 130 }, 131 "mistral": { 132 "sdk": "openai", 133 "base_url": "https://api.mistral.ai/v1", 134 "api_key": api_key or os.getenv("MISTRAL_API_KEY") or None 135 }, 136 "deepseek": { 137 "sdk": "openai", 138 "base_url": "https://api.deepseek.com", 139 "api_key": api_key or os.getenv("DEEPSEEK_API_KEY") or None 140 }, 141 "grok": { 142 "sdk": "grok", 143 "base_url": "https://api.x.ai/v1", 144 "api_key": api_key or os.getenv("GROK_API_KEY") or None 145 } 146 } 147 148 149 # Console log handler (always enabled at WARNING+) 150 if not self.logger.handlers: 151 console_handler = logging.StreamHandler(sys.stdout) 152 console_handler.setLevel(logging.WARNING) 153 console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) 154 self.logger.addHandler(console_handler) 155 156 self._log_enabled = False 157 if log_path: 158 file_handler = logging.FileHandler(log_path) 159 file_handler.setLevel(logging.DEBUG) 160 file_handler.setFormatter(logging.Formatter( 161 "%(asctime)s - %(levelname)s - %(message)s", 162 datefmt="%Y-%m-%d %H:%M:%S" 163 )) 164 self.logger.addHandler(file_handler) 165 self._log_enabled = True 166 167 168 self.token_estimate = 0 169 self.last_token_estimate = 0 170 self.stream = stream 171 self.tools = [] 172 self.session_history = [] 173 self.history = [] 174 self.context_length = context_length 175 self.encoding = None 176 self.max_retries = max_retries 177 self.retry_delay = retry_delay 178 self.reveal_tool = [] 179 self.fallback_model = fallback_model 180 self.sdk = None 181 182 # Session support 183 self.session_enabled = session_enabled 184 self.session_id = session_id 185 self._last_session_id = session_id 186 self.session = Session(directory=session_path) if session_enabled else None 187 188 189 if model is None: 190 model = "openai:gpt-4o-mini" 191 192 # Initialize model + encoding 193 self._setup_client(model, base_url, api_key) 194 self.tools_enabled = self.tools_supported if tools is None else tools and self.tools_supported 195 self._setup_encoding() 196 self.messages_add(role="system", content=self.system)
Initialize the universal AI interaction client.
Args: base_url: Optional base URL for the API. If None, uses the provider's default URL. api_key: Optional API key. If None, attempts to use environment variables based on provider. model: Model identifier in format "provider:model_name" (e.g., "openai:gpt-4o-mini"). tools: Enable (True) or disable (False) tool calling; None for auto-detection based on model support. stream: Enable (True) or disable (False) streaming responses. context_length: Maximum number of tokens to maintain in conversation history. max_retries: Maximum number of retries for failed API calls. retry_delay: Initial delay (in seconds) for exponential backoff retries. session_enabled: Enable persistent session support. session_id: Optional session ID to load messages from.
Raises: ValueError: If provider is not supported or API key is missing for non-Ollama providers.
368 def add_function( 369 self, 370 external_callable: Callable, 371 name: Optional[str] = None, 372 description: Optional[str] = None, 373 override: bool = False, 374 disabled: bool = False, 375 schema_extensions: Optional[Dict[str, Any]] = None 376 ): 377 """ 378 Register a function for LLM tool calling with full type hints and metadata. 379 380 Args: 381 external_callable (Callable): The function to register. 382 name (Optional[str]): Optional custom name. Defaults to function's __name__. 383 description (Optional[str]): Optional custom description. Defaults to first line of docstring. 384 override (bool): If True, replaces an existing tool with the same name. 385 disabled (bool): If True, registers the function in a disabled state. 386 schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to 387 schema extensions that override or add to the auto-generated schema. 388 389 Raises: 390 ValueError: If the callable is invalid or duplicate name found without override. 391 392 Example: 393 interactor.add_function( 394 my_tool, 395 override=True, 396 disabled=False, 397 schema_extensions={ 398 "param1": {"minimum": 0, "maximum": 100}, 399 "param2": {"format": "email"} 400 } 401 ) 402 """ 403 def _python_type_to_schema(ptype: Any) -> dict: 404 """Convert a Python type annotation to OpenAI-compatible JSON Schema.""" 405 # Handle None case 406 if ptype is None: 407 return {"type": "null"} 408 409 # Get the origin and arguments of the type 410 origin = get_origin(ptype) 411 args = get_args(ptype) 412 413 # Handle Union types (including Optional) 414 if origin is Union: 415 # Check for Optional (Union with None) 416 none_type = type(None) 417 if none_type in args: 418 non_none = [a for a in args if a is not none_type] 419 if len(non_none) == 1: 420 inner = _python_type_to_schema(non_none[0]) 421 inner_copy = inner.copy() 422 inner_copy["nullable"] = True 423 return inner_copy 424 # Multiple types excluding None 425 types = [_python_type_to_schema(a) for a in non_none] 426 return {"anyOf": types, "nullable": True} 427 # Regular Union without None 428 return {"anyOf": [_python_type_to_schema(a) for a in args]} 429 430 # Handle List and similar container types 431 if origin in (list, List): 432 item_type = args[0] if args else Any 433 if item_type is Any: 434 return {"type": "array"} 435 return {"type": "array", "items": _python_type_to_schema(item_type)} 436 437 # Handle Dict types with typing info 438 if origin in (dict, Dict): 439 if not args or len(args) != 2: 440 return {"type": "object"} 441 442 key_type, val_type = args 443 # We can only really use val_type in JSON Schema 444 if val_type is not Any and val_type is not object: 445 return { 446 "type": "object", 447 "additionalProperties": _python_type_to_schema(val_type) 448 } 449 return {"type": "object"} 450 451 # Handle Literal types for enums 452 if origin is Literal: 453 values = args 454 # Try to determine type from values 455 if all(isinstance(v, str) for v in values): 456 return {"type": "string", "enum": list(values)} 457 elif all(isinstance(v, bool) for v in values): 458 return {"type": "boolean", "enum": list(values)} 459 elif all(isinstance(v, (int, float)) for v in values): 460 return {"type": "number", "enum": list(values)} 461 else: 462 # Mixed types, use anyOf 463 return {"anyOf": [{"type": _get_json_type(v), "enum": [v]} for v in values]} 464 465 # Handle basic types 466 if ptype is str: 467 return {"type": "string"} 468 if ptype is int: 469 return {"type": "integer"} 470 if ptype is float: 471 return {"type": "number"} 472 if ptype is bool: 473 return {"type": "boolean"} 474 475 # Handle common datetime types 476 if ptype is datetime: 477 return {"type": "string", "format": "date-time"} 478 if ptype is date: 479 return {"type": "string", "format": "date"} 480 481 # Handle UUID 482 if ptype is uuid.UUID: 483 return {"type": "string", "format": "uuid"} 484 485 # Default to object for any other types 486 return {"type": "object"} 487 488 def _get_json_type(value): 489 """Get the JSON Schema type name for a Python value. 490 491 This helper function maps Python types to their corresponding 492 JSON Schema type names. It handles basic types and provides 493 sensible defaults for complex types. 494 495 Args: 496 value: The Python value to get the JSON type for 497 498 Returns: 499 str: The JSON Schema type name ('string', 'number', 'boolean', 500 'array', 'object', or 'object' as default) 501 """ 502 if isinstance(value, str): 503 return "string" 504 elif isinstance(value, bool): 505 return "boolean" 506 elif isinstance(value, int) or isinstance(value, float): 507 return "number" 508 elif isinstance(value, list): 509 return "array" 510 elif isinstance(value, dict): 511 return "object" 512 else: 513 return "object" # Default 514 515 def _parse_param_docs(docstring: str) -> dict: 516 """Extract parameter descriptions from a docstring.""" 517 if not docstring: 518 return {} 519 520 lines = docstring.splitlines() 521 param_docs = {} 522 current_param = None 523 in_params = False 524 525 # Regular expressions for finding parameter sections and param lines 526 param_section_re = re.compile(r"^(Args|Parameters):\s*$") 527 param_line_re = re.compile(r"^\s{4}(\w+)\s*(?:\([^\)]*\))?:\s*(.*)") 528 529 for line in lines: 530 # Check if we're entering the parameters section 531 if param_section_re.match(line.strip()): 532 in_params = True 533 continue 534 535 if in_params: 536 # Skip empty lines 537 if not line.strip(): 538 continue 539 540 # Check for a parameter definition line 541 match = param_line_re.match(line) 542 if match: 543 current_param = match.group(1) 544 param_docs[current_param] = match.group(2).strip() 545 # Check for continuation of a parameter description 546 elif current_param and line.startswith(" " * 8): 547 param_docs[current_param] += " " + line.strip() 548 # If we see a line that doesn't match our patterns, we're out of the params section 549 else: 550 current_param = None 551 552 return param_docs 553 554 # Start of main function logic 555 556 # Skip if tools are disabled 557 if not self.tools_enabled: 558 return 559 560 # Validate input callable 561 if not external_callable: 562 raise ValueError("A valid external callable must be provided.") 563 564 # Set function name, either from parameter or from callable's __name__ 565 function_name = name or external_callable.__name__ 566 567 # Try to get docstring and extract description 568 try: 569 docstring = inspect.getdoc(external_callable) 570 description = description or (docstring.split("\n")[0].strip() if docstring else "No description provided.") 571 except Exception as e: 572 self._log(f"[TOOL] Warning: Could not extract docstring from {function_name}: {e}", level="warning") 573 docstring = "" 574 description = description or "No description provided." 575 576 # Extract parameter documentation from docstring 577 param_docs = _parse_param_docs(docstring) 578 579 # Handle conflicts with existing functions 580 if override: 581 self.delete_function(function_name) 582 elif any(t["function"]["name"] == function_name for t in self.tools): 583 raise ValueError(f"Function '{function_name}' is already registered. Use override=True to replace.") 584 585 # Try to get function signature for parameter info 586 try: 587 signature = inspect.signature(external_callable) 588 except (ValueError, TypeError) as e: 589 raise ValueError(f"Cannot inspect callable '{function_name}': {e}") 590 591 # Process parameters to build schema 592 properties = {} 593 required = [] 594 595 for param_name, param in signature.parameters.items(): 596 # Skip self, cls parameters for instance/class methods 597 if param_name in ("self", "cls") and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: 598 continue 599 600 # Get parameter annotation, defaulting to Any 601 annotation = param.annotation if param.annotation != inspect.Parameter.empty else Any 602 603 try: 604 # Convert Python type to JSON Schema 605 schema = _python_type_to_schema(annotation) 606 607 # Add description from docstring or create a default one 608 schema["description"] = param_docs.get(param_name, f"{param_name} parameter") 609 610 # Add to properties 611 properties[param_name] = schema 612 613 # If no default value is provided, parameter is required 614 if param.default == inspect.Parameter.empty: 615 required.append(param_name) 616 self._log(f"[TOOL] Parameter '{param_name}' is required", level="debug") 617 else: 618 self._log(f"[TOOL] Parameter '{param_name}' has default value: {param.default}", level="debug") 619 620 except Exception as e: 621 self._log(f"[TOOL] Error processing parameter {param_name} for {function_name}: {e}", level="error") 622 # Add a basic object schema as fallback 623 properties[param_name] = { 624 "type": "string", # Default to string instead of object for better compatibility 625 "description": f"{param_name} parameter (type conversion failed)" 626 } 627 628 # For parameters with no default value, mark as required even if processing failed 629 if param.default == inspect.Parameter.empty: 630 required.append(param_name) 631 self._log(f"[TOOL] Parameter '{param_name}' marked as required despite conversion failure", level="debug") 632 633 # Apply schema extensions if provided 634 if schema_extensions: 635 for param_name, extensions in schema_extensions.items(): 636 if param_name in properties: 637 properties[param_name].update(extensions) 638 639 # Create parameters object with proper placement of 'required' field 640 parameters = { 641 "type": "object", 642 "properties": properties, 643 } 644 645 # Only add required field if there are required parameters 646 if required: 647 parameters["required"] = required 648 649 # Build the final tool specification 650 tool_spec = { 651 "type": "function", 652 "function": { 653 "name": function_name, 654 "description": description, 655 "parameters": parameters 656 } 657 } 658 659 # Set disabled flag if requested 660 if disabled: 661 tool_spec["function"]["disabled"] = True 662 663 # Add to tools list 664 self.tools.append(tool_spec) 665 666 # Make the function available as an attribute on the instance 667 setattr(self, function_name, external_callable) 668 669 # Log the registration with detailed information 670 self._log(f"[TOOL] Registered function '{function_name}' with {len(properties)} parameters", level="info") 671 if required: 672 self._log(f"[TOOL] Required parameters: {required}", level="info") 673 674 return function_name # Return the name for reference
Register a function for LLM tool calling with full type hints and metadata.
Args: external_callable (Callable): The function to register. name (Optional[str]): Optional custom name. Defaults to function's __name__. description (Optional[str]): Optional custom description. Defaults to first line of docstring. override (bool): If True, replaces an existing tool with the same name. disabled (bool): If True, registers the function in a disabled state. schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to schema extensions that override or add to the auto-generated schema.
Raises: ValueError: If the callable is invalid or duplicate name found without override.
Example: interactor.add_function( my_tool, override=True, disabled=False, schema_extensions={ "param1": {"minimum": 0, "maximum": 100}, "param2": {"format": "email"} } )
677 def disable_function(self, name: str) -> bool: 678 """ 679 Disable a registered tool function by name. 680 681 This marks the function as inactive for tool calling without removing it from the internal registry. 682 The function remains visible in the tool listing but is skipped during tool selection by the LLM. 683 684 Args: 685 name (str): The name of the function to disable. 686 687 Returns: 688 bool: True if the function was found and disabled, False otherwise. 689 690 Example: 691 interactor.disable_function("extract_text") 692 """ 693 for tool in self.tools: 694 if tool["function"]["name"] == name: 695 tool["function"]["disabled"] = True 696 return True 697 return False
Disable a registered tool function by name.
This marks the function as inactive for tool calling without removing it from the internal registry. The function remains visible in the tool listing but is skipped during tool selection by the LLM.
Args: name (str): The name of the function to disable.
Returns: bool: True if the function was found and disabled, False otherwise.
Example: interactor.disable_function("extract_text")
700 def enable_function(self, name: str) -> bool: 701 """ 702 Re-enable a previously disabled tool function by name. 703 704 This removes the 'disabled' flag from a tool function, making it available again for LLM use. 705 706 Args: 707 name (str): The name of the function to enable. 708 709 Returns: 710 bool: True if the function was found and enabled, False otherwise. 711 712 Example: 713 interactor.enable_function("extract_text") 714 """ 715 for tool in self.tools: 716 if tool["function"]["name"] == name: 717 tool["function"].pop("disabled", None) 718 return True 719 return False
Re-enable a previously disabled tool function by name.
This removes the 'disabled' flag from a tool function, making it available again for LLM use.
Args: name (str): The name of the function to enable.
Returns: bool: True if the function was found and enabled, False otherwise.
Example: interactor.enable_function("extract_text")
722 def delete_function(self, name: str) -> bool: 723 """ 724 Permanently remove a registered tool function from the Interactor. 725 726 This deletes both the tool metadata and the callable attribute, making it fully inaccessible 727 from the active session. Useful for dynamically trimming the toolset. 728 729 Args: 730 name (str): The name of the function to delete. 731 732 Returns: 733 bool: True if the function was found and removed, False otherwise. 734 735 Example: 736 interactor.delete_function("extract_text") 737 """ 738 before = len(self.tools) 739 self.tools = [tool for tool in self.tools if tool["function"]["name"] != name] 740 if hasattr(self, name): 741 delattr(self, name) 742 return len(self.tools) < before
Permanently remove a registered tool function from the Interactor.
This deletes both the tool metadata and the callable attribute, making it fully inaccessible from the active session. Useful for dynamically trimming the toolset.
Args: name (str): The name of the function to delete.
Returns: bool: True if the function was found and removed, False otherwise.
Example: interactor.delete_function("extract_text")
745 def list_functions(self) -> List[Dict[str, Any]]: 746 """Get the list of registered functions for tool calling. 747 748 Returns: 749 List[Dict[str, Any]]: List of registered functions. 750 """ 751 return self.tools
Get the list of registered functions for tool calling.
Returns: List[Dict[str, Any]]: List of registered functions.
754 def list_models( 755 self, 756 providers: Optional[Union[str, List[str]]] = None, 757 filter: Optional[str] = None 758 ) -> List[str]: 759 """Retrieve available models from configured providers. 760 761 Args: 762 providers: Provider name or list of provider names. If None, all are queried. 763 filter: Optional regex to filter model names. 764 765 Returns: 766 List[str]: Sorted list of "provider:model_id" strings. 767 """ 768 models = [] 769 770 if providers is None: 771 providers_to_list = self.providers 772 elif isinstance(providers, str): 773 providers_to_list = {providers: self.providers.get(providers)} 774 elif isinstance(providers, list): 775 providers_to_list = {p: self.providers.get(p) for p in providers} 776 else: 777 return [] 778 779 invalid_providers = [p for p in providers_to_list if p not in self.providers or self.providers[p] is None] 780 if invalid_providers: 781 self.logger.error(f"Invalid providers: {invalid_providers}") 782 return [] 783 784 regex_pattern = None 785 if filter: 786 try: 787 regex_pattern = re.compile(filter, re.IGNORECASE) 788 except re.error as e: 789 self.logger.error(f"Invalid regex pattern: {e}") 790 return [] 791 792 for provider_name, config in providers_to_list.items(): 793 sdk = config.get("sdk", "openai") 794 base_url = config.get("base_url") 795 api_key = config.get("api_key") 796 797 try: 798 if sdk == "openai": 799 client = openai.OpenAI(api_key=api_key, base_url=base_url) 800 response = client.models.list() 801 for model in response.data: 802 model_id = f"{provider_name}:{model.id}" 803 if not regex_pattern or regex_pattern.search(model_id): 804 models.append(model_id) 805 806 elif sdk == "anthropic": 807 client = Anthropic(api_key=api_key) 808 response = client.models.list() 809 for model in response: 810 model_id = f"{provider_name}:{model.id}" 811 if not regex_pattern or regex_pattern.search(model_id): 812 models.append(model_id) 813 else: 814 self.logger.warning(f"SDK '{sdk}' for provider '{provider_name}' is not supported by list_models()") 815 816 except Exception as e: 817 self.logger.error(f"Failed to list models for {provider_name}: {e}") 818 819 return sorted(models, key=str.lower)
Retrieve available models from configured providers.
Args: providers: Provider name or list of provider names. If None, all are queried. filter: Optional regex to filter model names.
Returns: List[str]: Sorted list of "provider:model_id" strings.
872 def interact( 873 self, 874 user_input: Optional[str], 875 quiet: bool = False, 876 tools: bool = True, 877 stream: bool = True, 878 markdown: bool = False, 879 model: Optional[str] = None, 880 output_callback: Optional[Callable[[str], None]] = None, 881 session_id: Optional[str] = None, 882 raw: Optional[bool] = None, 883 tool_suppress: bool = True, 884 timeout: float = 60.0 885 ) -> Union[Optional[str], "TokenStream"]: 886 """Main universal gateway for all LLM interaction. 887 888 This function serves as the single entry point for all interactions with the language model. 889 When `raw=False` (default), it handles the interaction internally and returns the full response. 890 When `raw=True`, it returns a context manager that yields chunks of the response for custom handling. 891 892 Args: 893 user_input: Text input from the user. 894 quiet: If True, don't print status info or progress. 895 tools: Enable (True) or disable (False) tool calling. 896 stream: Enable (True) or disable (False) streaming responses. 897 markdown: If True, renders content as markdown. 898 model: Optional model override. 899 output_callback: Optional callback to handle the output. 900 session_id: Optional session ID to load messages from. 901 raw: If True, return a context manager instead of handling the interaction internally. 902 If None, use the class-level setting from __init__. 903 tool_suppress: If True and raw=True, filter out tool-related status messages. 904 timeout: Maximum time in seconds to wait for the stream to complete when raw=True. 905 906 Returns: 907 If raw=False: The complete response from the model as a string, or None if there was an error. 908 If raw=True: A context manager that yields chunks of the response as they arrive. 909 910 Example with default mode: 911 response = ai.interact("Tell me a joke") 912 913 Example with raw mode: 914 with ai.interact("Tell me a joke", raw=True) as stream: 915 for chunk in stream: 916 print(chunk, end="", flush=True) 917 """ 918 if not user_input: 919 return None 920 921 if quiet or self.quiet: 922 markdown = False 923 stream = False 924 925 # Determine if we should use raw mode 926 # If raw parameter is explicitly provided, use that; otherwise use class setting 927 use_raw = self.raw if raw is None else raw 928 929 # If raw mode is requested, delegate to interact_raw 930 if use_raw: 931 return self._interact_raw( 932 user_input=user_input, 933 tools=tools, 934 model=model, 935 session_id=session_id, 936 tool_suppress=tool_suppress, 937 timeout=timeout 938 ) 939 940 # Setup model if specified 941 if model: 942 self._setup_client(model) 943 self._setup_encoding() 944 945 # Session handling 946 if self.session_enabled and session_id: 947 self.session_id = session_id 948 self.session_load(session_id) 949 950 # Add user message using messages_add 951 self.messages_add(role="user", content=user_input) 952 953 # Log token count estimate 954 token_count = self._count_tokens(self.history) 955 if not quiet: 956 print(f"[dim]Estimated tokens in context: {token_count} / {self.context_length}[/dim]") 957 958 # Make sure we have enough context space 959 if token_count > self.context_length: 960 if self._cycle_messages(): 961 if not quiet: 962 print("[red]Context window exceeded. Cannot proceed.[/red]") 963 return None 964 965 # Log user input 966 self._log(f"[USER] {user_input}") 967 968 # Handle the actual interaction with complete streaming for all responses 969 result = asyncio.run(self._interact_async_core( 970 user_input=user_input, 971 quiet=quiet, 972 tools=tools, 973 stream=stream, 974 markdown=markdown, 975 output_callback=output_callback 976 )) 977 978 # Log completion for this interaction 979 self._log(f"[INTERACTION] Completed with {len(self.history)} total messages") 980 981 return result
Main universal gateway for all LLM interaction.
This function serves as the single entry point for all interactions with the language model.
When raw=False
(default), it handles the interaction internally and returns the full response.
When raw=True
, it returns a context manager that yields chunks of the response for custom handling.
Args: user_input: Text input from the user. quiet: If True, don't print status info or progress. tools: Enable (True) or disable (False) tool calling. stream: Enable (True) or disable (False) streaming responses. markdown: If True, renders content as markdown. model: Optional model override. output_callback: Optional callback to handle the output. session_id: Optional session ID to load messages from. raw: If True, return a context manager instead of handling the interaction internally. If None, use the class-level setting from __init__. tool_suppress: If True and raw=True, filter out tool-related status messages. timeout: Maximum time in seconds to wait for the stream to complete when raw=True.
Returns: If raw=False: The complete response from the model as a string, or None if there was an error. If raw=True: A context manager that yields chunks of the response as they arrive.
Example with default mode: response = ai.interact("Tell me a joke")
Example with raw mode: with ai.interact("Tell me a joke", raw=True) as stream: for chunk in stream: print(chunk, end="", flush=True)
2182 def messages_add( 2183 self, 2184 role: str, 2185 content: Any, 2186 tool_info: Optional[Dict] = None, 2187 normalize: bool = True 2188 ) -> str: 2189 """ 2190 Add a message to the standardized session_history and then update SDK-specific history. 2191 2192 This method is the central point for all message additions to the conversation. 2193 2194 Args: 2195 role: The role of the message ("user", "assistant", "system", "tool") 2196 content: The message content (text or structured) 2197 tool_info: Optional tool-related metadata 2198 normalize: Whether to normalize history after adding this message 2199 2200 Returns: 2201 str: Unique ID of the added message 2202 """ 2203 # Generate a unique message ID 2204 message_id = str(uuid.uuid4()) 2205 2206 # Create the standardized message for session_history 2207 timestamp = datetime.now(timezone.utc).isoformat() 2208 2209 # Store system messages directly 2210 if role == "system": 2211 self.system = content 2212 2213 # Create standard format message 2214 standard_message = { 2215 "role": role, 2216 "content": content, 2217 "id": message_id, 2218 "timestamp": timestamp, 2219 "metadata": { 2220 "sdk": self.sdk 2221 } 2222 } 2223 2224 # Add tool info if provided 2225 if tool_info: 2226 standard_message["metadata"]["tool_info"] = tool_info 2227 2228 # Add to session_history 2229 if not hasattr(self, "session_history"): 2230 self.session_history = [] 2231 2232 self.session_history.append(standard_message) 2233 2234 # Save to persistent session if enabled 2235 if self.session_enabled and self.session_id: 2236 # Convert standard message to session-compatible format 2237 session_msg = { 2238 "role": role, 2239 "content": content, 2240 "id": message_id, 2241 "timestamp": timestamp 2242 } 2243 2244 # Add tool-related fields if present 2245 if tool_info: 2246 for key, value in tool_info.items(): 2247 session_msg[key] = value 2248 2249 # Store in session 2250 self.session.msg_insert(self.session_id, session_msg) 2251 2252 # Update the SDK-specific format in self.history by running the normalizer 2253 if normalize: 2254 # We only need to normalize the most recent message for efficiency 2255 # Pass a flag indicating we're just normalizing a new message 2256 self._normalizer(force=False, new_message_only=True) 2257 2258 # Log the added message 2259 self._log(f"[MESSAGE ADDED] {role}: {str(content)[:50]}...") 2260 2261 return message_id
Add a message to the standardized session_history and then update SDK-specific history.
This method is the central point for all message additions to the conversation.
Args: role: The role of the message ("user", "assistant", "system", "tool") content: The message content (text or structured) tool_info: Optional tool-related metadata normalize: Whether to normalize history after adding this message
Returns: str: Unique ID of the added message
2264 def messages_system(self, prompt: str): 2265 """Set or retrieve the current system prompt. 2266 2267 This method manages the system prompt that guides the AI's behavior. 2268 It can be used to both set a new system prompt and retrieve the current one. 2269 When setting a new prompt, it updates the system message in the conversation 2270 history and persists it to the session if enabled. 2271 2272 Args: 2273 prompt (str): The new system prompt to set. If empty or None, 2274 returns the current system prompt without changes. 2275 2276 Returns: 2277 str: The current system prompt after any updates 2278 """ 2279 if not isinstance(prompt, str) or not prompt: 2280 return self.system 2281 2282 # If the prompt hasn't changed, don't do anything 2283 if self.system == prompt: 2284 return self.system 2285 2286 # Update the system prompt 2287 old_system = self.system 2288 self.system = prompt 2289 2290 # For OpenAI, update or insert the system message in history 2291 if self.sdk == "openai": 2292 # Check if there's already a system message 2293 system_index = next((i for i, msg in enumerate(self.history) 2294 if msg.get("role") == "system"), None) 2295 2296 if system_index is not None: 2297 # Update existing system message 2298 self.history[system_index]["content"] = prompt 2299 else: 2300 # Insert new system message at the beginning 2301 self.history.insert(0, {"role": "system", "content": prompt}) 2302 2303 # For Anthropic, system message is not part of history, just save it for API calls 2304 2305 # Log to session only if prompt actually changed 2306 if self.session_enabled and self.session_id and old_system != prompt: 2307 self.session.msg_insert(self.session_id, {"role": "system", "content": prompt}) 2308 2309 return self.system
Set or retrieve the current system prompt.
This method manages the system prompt that guides the AI's behavior. It can be used to both set a new system prompt and retrieve the current one. When setting a new prompt, it updates the system message in the conversation history and persists it to the session if enabled.
Args: prompt (str): The new system prompt to set. If empty or None, returns the current system prompt without changes.
Returns: str: The current system prompt after any updates
2312 def messages(self) -> list: 2313 """Return full session messages (persisted or in-memory).""" 2314 if self.session_enabled and self.session_id: 2315 return self.session.load_full(self.session_id).get("messages", []) 2316 return self.session_history
Return full session messages (persisted or in-memory).
2319 def messages_length(self) -> int: 2320 """Calculate the total token count for the message history.""" 2321 if not self.encoding: 2322 return 0 2323 2324 total_tokens = 0 2325 for message in self.history: 2326 if message.get("content"): 2327 total_tokens += len(self.encoding.encode(message["content"])) 2328 if message.get("tool_calls"): 2329 for tool_call in message["tool_calls"]: 2330 if tool_call.get("function"): 2331 total_tokens += len(self.encoding.encode(tool_call["function"].get("name", ""))) 2332 total_tokens += len(self.encoding.encode(tool_call["function"].get("arguments", ""))) 2333 return total_tokens
Calculate the total token count for the message history.
2336 def session_load(self, session_id: Optional[str]): 2337 """Load and normalize messages for a specific session. 2338 2339 This method loads a conversation session from persistent storage and 2340 normalizes the messages to the current SDK format. It handles system 2341 messages, tool calls, and maintains message ordering. If loading fails, 2342 it resets to an empty session with the default system prompt. 2343 2344 Args: 2345 session_id (Optional[str]): The ID of the session to load. 2346 If None, resets to in-memory mode. 2347 2348 Note: 2349 This method will update both session_history and history to match 2350 the loaded session's state. It also ensures proper SDK-specific 2351 message formatting. 2352 """ 2353 self.session_id = session_id 2354 self._last_session_id = session_id 2355 2356 if self.session_enabled and session_id: 2357 try: 2358 # Load raw session data 2359 session_data = self.session.load_full(session_id) 2360 messages = session_data.get("messages", []) 2361 2362 # Convert session format to our standard format 2363 self.session_history = [] 2364 2365 # Track the most recent system message 2366 latest_system_msg = None 2367 2368 for msg in messages: 2369 # Extract fields 2370 role = msg.get("role", "user") 2371 content = msg.get("content", "") 2372 msg_id = msg.get("id", str(uuid.uuid4())) 2373 timestamp = msg.get("timestamp", datetime.now(timezone.utc).isoformat()) 2374 2375 # If this is a system message, track it but don't add to session_history yet 2376 if role == "system": 2377 if latest_system_msg is None or timestamp > latest_system_msg["timestamp"]: 2378 latest_system_msg = { 2379 "role": role, 2380 "content": content, 2381 "id": msg_id, 2382 "timestamp": timestamp, 2383 "metadata": {"sdk": self.sdk} 2384 } 2385 continue 2386 2387 # Build tool_info if present 2388 tool_info = None 2389 if any(key in msg for key in ["tool_use_id", "tool_call_id", "name", "arguments"]): 2390 tool_info = { 2391 "id": msg.get("tool_use_id") or msg.get("tool_call_id"), 2392 "name": msg.get("name", "unknown_tool"), 2393 "arguments": msg.get("arguments", {}) 2394 } 2395 2396 # Create standard message 2397 standard_msg = { 2398 "role": role, 2399 "content": content, 2400 "id": msg_id, 2401 "timestamp": timestamp, 2402 "metadata": { 2403 "sdk": self.sdk 2404 } 2405 } 2406 2407 if tool_info: 2408 standard_msg["metadata"]["tool_info"] = tool_info 2409 2410 self.session_history.append(standard_msg) 2411 2412 # If we found a system message, update the system property and add to history 2413 if latest_system_msg: 2414 self.system = latest_system_msg["content"] 2415 # Insert at the beginning of session_history 2416 self.session_history.insert(0, latest_system_msg) 2417 else: 2418 # If no system message was found, add the current system message 2419 self.messages_add(role="system", content=self.system) 2420 2421 # Normalize to current SDK format 2422 self._normalizer(force=True) 2423 2424 self._log(f"[SESSION] Switched to session '{session_id}'") 2425 except Exception as e: 2426 self.logger.error(f"Failed to load session '{session_id}': {e}") 2427 self.session_reset() 2428 else: 2429 # Reset to empty state with system message 2430 self.session_reset()
Load and normalize messages for a specific session.
This method loads a conversation session from persistent storage and normalizes the messages to the current SDK format. It handles system messages, tool calls, and maintains message ordering. If loading fails, it resets to an empty session with the default system prompt.
Args: session_id (Optional[str]): The ID of the session to load. If None, resets to in-memory mode.
Note: This method will update both session_history and history to match the loaded session's state. It also ensures proper SDK-specific message formatting.
2433 def session_reset(self): 2434 """Reset the current session state and reinitialize to default system prompt. 2435 2436 This method performs a complete reset of the conversation state: 2437 1. Clears all message history 2438 2. Disables session ID tracking 2439 3. Returns to in-memory mode 2440 4. Reinitializes with the default system prompt 2441 2442 The reset is useful for starting fresh conversations or recovering 2443 from error states. It maintains the basic system configuration while 2444 clearing all conversation context. 2445 """ 2446 self.session_id = None 2447 self._last_session_id = None 2448 2449 # Clear histories 2450 self.session_history = [] 2451 self.history = [] 2452 2453 # Reapply the system message 2454 if hasattr(self, "system") and self.system: 2455 # Add to session_history 2456 self.messages_add(role="system", content=self.system) 2457 else: 2458 # Ensure we have a default system message 2459 self.system = "You are a helpful Assistant." 2460 self.messages_add(role="system", content=self.system) 2461 2462 self._log("[SESSION] Reset to in-memory mode")
Reset the current session state and reinitialize to default system prompt.
This method performs a complete reset of the conversation state:
- Clears all message history
- Disables session ID tracking
- Returns to in-memory mode
- Reinitializes with the default system prompt
The reset is useful for starting fresh conversations or recovering from error states. It maintains the basic system configuration while clearing all conversation context.
2906 def track_token_usage(self): 2907 """Track and return token usage across the conversation history. 2908 2909 This method maintains a history of token usage measurements and provides 2910 current usage statistics. It tracks: 2911 - Current token count 2912 - Context length limit 2913 - Usage percentage 2914 - Historical measurements 2915 - Current provider and model info 2916 2917 Returns: 2918 dict: Dictionary containing: 2919 - current: Current token count 2920 - limit: Maximum context length 2921 - percentage: Usage as percentage of limit 2922 - history: Last 10 measurements with timestamps 2923 - provider: Current provider name 2924 - model: Current model name 2925 2926 Note: 2927 The history is limited to the last 100 measurements to prevent 2928 unbounded memory growth. 2929 """ 2930 if not hasattr(self, "_token_history"): 2931 self._token_history = [] 2932 2933 # Count current tokens 2934 current_count = self._count_tokens(self.history) 2935 2936 # Add to history 2937 timestamp = datetime.now(timezone.utc).isoformat() 2938 self._token_history.append({ 2939 "timestamp": timestamp, 2940 "count": current_count, 2941 "limit": self.context_length, 2942 "provider": self.provider, 2943 "model": self.model 2944 }) 2945 2946 # Keep only the last 100 measurements to avoid unlimited growth 2947 if len(self._token_history) > 100: 2948 self._token_history = self._token_history[-100:] 2949 2950 # Return current tracking info 2951 return { 2952 "current": current_count, 2953 "limit": self.context_length, 2954 "percentage": round((current_count / self.context_length) * 100, 1) if self.context_length else 0, 2955 "history": self._token_history[-10:], # Return last 10 measurements 2956 "provider": self.provider, 2957 "model": self.model 2958 }
Track and return token usage across the conversation history.
This method maintains a history of token usage measurements and provides current usage statistics. It tracks:
- Current token count
- Context length limit
- Usage percentage
- Historical measurements
- Current provider and model info
Returns: dict: Dictionary containing: - current: Current token count - limit: Maximum context length - percentage: Usage as percentage of limit - history: Last 10 measurements with timestamps - provider: Current provider name - model: Current model name
Note: The history is limited to the last 100 measurements to prevent unbounded memory growth.
2961 def get_message_token_breakdown(self): 2962 """Analyze token usage by message type and provide a detailed breakdown. 2963 2964 This method performs a detailed analysis of token usage across the 2965 conversation history, breaking down usage by: 2966 - Message role (system, user, assistant, tool) 2967 - Content type (text, tool calls, tool results) 2968 - Individual message statistics 2969 2970 Returns: 2971 dict: Token usage breakdown containing: 2972 - total: Total tokens used 2973 - by_role: Tokens used by each role 2974 - by_type: Tokens used by content type 2975 - messages: List of individual message stats including: 2976 - index: Message position 2977 - role: Message role 2978 - tokens: Tokens used 2979 - has_tools: Whether message contains tool calls 2980 2981 Note: 2982 This analysis is useful for understanding token usage patterns 2983 and optimizing conversation context. 2984 """ 2985 breakdown = { 2986 "total": 0, 2987 "by_role": { 2988 "system": 0, 2989 "user": 0, 2990 "assistant": 0, 2991 "tool": 0 2992 }, 2993 "by_type": { 2994 "text": 0, 2995 "tool_calls": 0, 2996 "tool_results": 0 2997 }, 2998 "messages": [] 2999 } 3000 3001 # Analyze each message 3002 for i, msg in enumerate(self.history): 3003 msg_tokens = self._count_tokens([msg]) 3004 role = msg.get("role", "unknown") 3005 3006 # Track by role 3007 if role in breakdown["by_role"]: 3008 breakdown["by_role"][role] += msg_tokens 3009 3010 # Track by content type 3011 if role == "assistant" and msg.get("tool_calls"): 3012 breakdown["by_type"]["tool_calls"] += msg_tokens 3013 elif role == "tool": 3014 breakdown["by_type"]["tool_results"] += msg_tokens 3015 else: 3016 breakdown["by_type"]["text"] += msg_tokens 3017 3018 # Add individual message data 3019 breakdown["messages"].append({ 3020 "index": i, 3021 "role": role, 3022 "tokens": msg_tokens, 3023 "has_tools": bool(msg.get("tool_calls") or msg.get("tool_use") or 3024 (isinstance(msg.get("content"), list) and 3025 any(isinstance(c, dict) and c.get("type") in ["tool_use", "tool_result"] 3026 for c in msg.get("content", [])))) 3027 }) 3028 3029 # Update tota? 3030 breakdown["total"] += msg_tokens 3031 3032 return breakdown
Analyze token usage by message type and provide a detailed breakdown.
This method performs a detailed analysis of token usage across the conversation history, breaking down usage by:
- Message role (system, user, assistant, tool)
- Content type (text, tool calls, tool results)
- Individual message statistics
Returns: dict: Token usage breakdown containing: - total: Total tokens used - by_role: Tokens used by each role - by_type: Tokens used by content type - messages: List of individual message stats including: - index: Message position - role: Message role - tokens: Tokens used - has_tools: Whether message contains tool calls
Note: This analysis is useful for understanding token usage patterns and optimizing conversation context.
19class Session: 20 def __init__(self, directory: str = None): 21 """ 22 Initialize the session manager and ensure the session directory exists. 23 24 Args: 25 directory (str): Filesystem path for session storage. Must not be None or empty. 26 27 Raises: 28 ValueError: If directory is None or not a string. 29 OSError: If the directory cannot be created or accessed. 30 """ 31 if not directory: 32 raise ValueError("Session directory must be a valid non-empty string path.") 33 34 try: 35 self.path = Path(os.path.expanduser(directory)) 36 self.path.mkdir(parents=True, exist_ok=True) 37 except Exception as e: 38 raise OSError(f"Failed to initialize session directory '{directory}': {e}") 39 40 # --------------------------- 41 # Core CRUD 42 # --------------------------- 43 44 def list(self) -> List[Dict]: 45 """ 46 Return metadata for all sessions in the directory. 47 48 Returns: 49 List[Dict]: Sorted list of session metadata dictionaries. 50 """ 51 out = [] 52 for file in self.path.glob("*.json"): 53 try: 54 with open(file, "r") as f: 55 d = json.load(f) 56 out.append({ 57 "id": d.get("id"), 58 "name": d.get("name"), 59 "created": d.get("created"), 60 "tags": d.get("tags", []), 61 "summary": d.get("summary") 62 }) 63 except Exception: 64 continue 65 return sorted(out, key=lambda x: x["created"], reverse=True) 66 67 def create(self, name: str, tags: Optional[List[str]] = None) -> str: 68 """ 69 Create and persist a new session. 70 71 Args: 72 name (str): Name of the new session. 73 tags (List[str], optional): Optional list of tags. 74 75 Returns: 76 str: Unique session ID of the new session. 77 """ 78 sid = str(uuid.uuid4()) 79 session = { 80 "id": sid, 81 "name": name, 82 "created": datetime.now(timezone.utc).isoformat(), 83 "parent": None, 84 "branch_point": None, 85 "tags": tags or [], 86 "summary": None, 87 "messages": [] 88 } 89 self._save_file(sid, session) 90 return sid 91 92 def load(self, session_id: str) -> List[Dict]: 93 """ 94 Return OpenAI-compatible message list from a session. 95 96 Filters out internal keys and leaves only standard API-compatible fields. 97 98 Args: 99 session_id (str): ID of the session to load. 100 101 Returns: 102 List[Dict]: List of clean message dictionaries. 103 """ 104 session = self._read_file(session_id) 105 return [ 106 {k: v for k, v in m.items() if k in { 107 "role", "content", "tool_calls", "name", "function_call", "tool_call_id" 108 }} for m in session.get("messages", []) 109 ] 110 111 def load_full(self, session_id: str) -> Dict: 112 """ 113 Return the complete session file as-is. 114 115 Args: 116 session_id (str): ID of the session. 117 118 Returns: 119 Dict: Entire raw session data. 120 """ 121 return self._read_file(session_id) 122 123 def delete(self, session_id: str): 124 """ 125 Delete a session file from disk. 126 127 Args: 128 session_id (str): ID of the session to delete. 129 """ 130 file = self.path / f"{session_id}.json" 131 if file.exists(): 132 file.unlink() 133 134 def update(self, session_id: str, key: str, value: Any): 135 """ 136 Update a top-level key in a session file. 137 138 Args: 139 session_id (str): Session ID. 140 key (str): Field to update. 141 value (Any): New value for the field. 142 """ 143 session = self._read_file(session_id) 144 session[key] = value 145 self._save_file(session_id, session) 146 147 # --------------------------- 148 # Message Operations 149 # --------------------------- 150 151 def msg_insert(self, session_id: str, message: Dict) -> str: 152 """ 153 Insert a new message into a session. 154 155 Args: 156 session_id (str): Session ID. 157 message (Dict): Message dictionary to insert. 158 159 Returns: 160 str: ID of the inserted message. 161 """ 162 session = self._read_file(session_id) 163 entry = { 164 "id": str(uuid.uuid4()), 165 "timestamp": datetime.now(timezone.utc).isoformat(), 166 **message 167 } 168 session["messages"].append(entry) 169 self._save_file(session_id, session) 170 return entry["id"] 171 172 def msg_get(self, session_id: str, message_id: str) -> Optional[Dict]: 173 """ 174 Retrieve a specific message from a session. 175 176 Args: 177 session_id (str): Session ID. 178 message_id (str): ID of the message to retrieve. 179 180 Returns: 181 Optional[Dict]: The message if found, else None. 182 """ 183 session = self._read_file(session_id) 184 for msg in session.get("messages", []): 185 if msg.get("id") == message_id: 186 return msg 187 return None 188 189 def msg_index(self, session_id: str, message_id: str) -> Optional[int]: 190 """ 191 Get the index of a message within a session. 192 193 Args: 194 session_id (str): Session ID. 195 message_id (str): Message ID. 196 197 Returns: 198 Optional[int]: Index if found, else None. 199 """ 200 session = self._read_file(session_id) 201 for i, msg in enumerate(session.get("messages", [])): 202 if msg.get("id") == message_id: 203 return i 204 return None 205 206 def msg_update(self, session_id: str, message_id: str, new_content: str) -> bool: 207 """ 208 Update the content of a specific message. 209 210 Args: 211 session_id (str): Session ID. 212 message_id (str): Message ID. 213 new_content (str): New content for the message. 214 215 Returns: 216 bool: True if update succeeded, False otherwise. 217 """ 218 session = self._read_file(session_id) 219 for m in session["messages"]: 220 if m.get("id") == message_id: 221 m["content"] = new_content 222 self._save_file(session_id, session) 223 return True 224 return False 225 226 def msg_delete(self, session_id: str, message_id: str) -> bool: 227 """ 228 Delete a message from a session. 229 230 Args: 231 session_id (str): Session ID. 232 message_id (str): Message ID. 233 234 Returns: 235 bool: True if deletion occurred, False otherwise. 236 """ 237 session = self._read_file(session_id) 238 before = len(session["messages"]) 239 session["messages"] = [m for m in session["messages"] if m.get("id") != message_id] 240 self._save_file(session_id, session) 241 return len(session["messages"]) < before 242 243 # --------------------------- 244 # Branching & Summarization 245 # --------------------------- 246 247 def branch(self, from_id: str, message_id: str, new_name: str) -> str: 248 """Create a new session by branching from a specific message. 249 250 This method creates a new session that branches from an existing one at a specific 251 message point. The new session inherits all messages up to and including the 252 specified message, then starts fresh from there. 253 254 Args: 255 from_id (str): ID of the source session to branch from. 256 message_id (str): ID of the message to branch at. 257 new_name (str): Name for the new branched session. 258 259 Returns: 260 str: ID of the newly created branched session. 261 262 Raises: 263 ValueError: If the source session or message ID is not found. 264 """ 265 # Get source session 266 source = self._read_file(from_id) 267 if not source: 268 raise ValueError(f"Source session '{from_id}' not found") 269 270 # Find the branch point 271 branch_index = self.msg_index(from_id, message_id) 272 if branch_index is None: 273 raise ValueError(f"Message '{message_id}' not found in session '{from_id}'") 274 275 # Create new session 276 new_id = self.create(new_name, source.get("tags", [])) 277 new_session = self._read_file(new_id) 278 279 # Copy messages up to branch point 280 new_session["messages"] = source["messages"][:branch_index + 1] 281 new_session["parent"] = from_id 282 new_session["branch_point"] = message_id 283 284 # Save and return 285 self._save_file(new_id, new_session) 286 return new_id 287 288 def summarize(self, interactor, session_id: str) -> str: 289 """Generate a summary of the session using the provided interactor. 290 291 This method uses the AI interactor to analyze the session content and generate 292 a concise summary. The summary is stored in the session metadata and returned. 293 294 Args: 295 interactor: An AI interactor instance capable of generating summaries. 296 session_id (str): ID of the session to summarize. 297 298 Returns: 299 str: The generated summary text. 300 301 Note: 302 The summary is automatically stored in the session metadata and can be 303 retrieved later using load_full(). 304 """ 305 session = self._read_file(session_id) 306 if not session: 307 return "" 308 309 # Get clean message list 310 messages = self.load(session_id) 311 if not messages: 312 return "" 313 314 # Generate summary 315 summary = interactor.interact( 316 "Summarize this conversation in 2-3 sentences:", 317 tools=False, 318 stream=False, 319 markdown=False 320 ) 321 322 # Store and return 323 session["summary"] = summary 324 self._save_file(session_id, session) 325 return summary 326 327 # --------------------------- 328 # Search Capabilities 329 # --------------------------- 330 331 def search(self, query: str, session_id: Optional[str] = None) -> List[Dict]: 332 """Search for messages containing the query text within a session or all sessions. 333 334 This method performs a case-insensitive text search across message content. 335 If a session_id is provided, only searches within that session. Otherwise, 336 searches across all sessions. 337 338 Args: 339 query (str): Text to search for. 340 session_id (Optional[str]): Optional session ID to limit search scope. 341 342 Returns: 343 List[Dict]: List of matching messages with their session context. 344 Each dict contains: 345 - session_id: ID of the containing session 346 - message: The matching message 347 - context: Surrounding messages for context 348 """ 349 results = [] 350 query = query.lower() 351 352 # Determine search scope 353 if session_id: 354 sessions = [(session_id, self._read_file(session_id))] 355 else: 356 sessions = [(f.stem, self._read_file(f.stem)) for f in self.path.glob("*.json")] 357 358 # Search each session 359 for sid, session in sessions: 360 if not session: 361 continue 362 363 messages = session.get("messages", []) 364 for i, msg in enumerate(messages): 365 content = str(msg.get("content", "")).lower() 366 if query in content: 367 # Get context (2 messages before and after) 368 start = max(0, i - 2) 369 end = min(len(messages), i + 3) 370 context = messages[start:end] 371 372 results.append({ 373 "session_id": sid, 374 "message": msg, 375 "context": context 376 }) 377 378 return results 379 380 def search_meta(self, query: str) -> List[Dict]: 381 """Search session metadata (name, tags, summary) for matching sessions. 382 383 This method performs a case-insensitive search across session metadata fields 384 including name, tags, and summary. It returns matching sessions with their 385 full metadata. 386 387 Args: 388 query (str): Text to search for in metadata. 389 390 Returns: 391 List[Dict]: List of matching session metadata dictionaries. 392 Each dict contains: 393 - id: Session ID 394 - name: Session name 395 - created: Creation timestamp 396 - tags: List of tags 397 - summary: Session summary if available 398 """ 399 results = [] 400 query = query.lower() 401 402 for file in self.path.glob("*.json"): 403 try: 404 with open(file, "r") as f: 405 session = json.load(f) 406 407 # Check metadata fields 408 name = str(session.get("name", "")).lower() 409 tags = [str(t).lower() for t in session.get("tags", [])] 410 summary = str(session.get("summary", "")).lower() 411 412 if (query in name or 413 any(query in tag for tag in tags) or 414 query in summary): 415 results.append({ 416 "id": session.get("id"), 417 "name": session.get("name"), 418 "created": session.get("created"), 419 "tags": session.get("tags", []), 420 "summary": session.get("summary") 421 }) 422 except Exception: 423 continue 424 425 return sorted(results, key=lambda x: x["created"], reverse=True) 426 427 # --------------------------- 428 # Internal I/O 429 # --------------------------- 430 431 def _read_file(self, session_id: str) -> Dict: 432 """Read and parse a session file from disk. 433 434 This internal method handles reading and parsing session files. 435 It ensures proper error handling and returns an empty session 436 structure if the file doesn't exist or is invalid. 437 438 Args: 439 session_id (str): ID of the session to read. 440 441 Returns: 442 Dict: Session data dictionary or empty session structure. 443 """ 444 file = self.path / f"{session_id}.json" 445 if not file.exists(): 446 return { 447 "id": session_id, 448 "name": "New Session", 449 "created": datetime.now(timezone.utc).isoformat(), 450 "messages": [] 451 } 452 453 try: 454 with open(file, "r") as f: 455 return json.load(f) 456 except Exception: 457 return { 458 "id": session_id, 459 "name": "New Session", 460 "created": datetime.now(timezone.utc).isoformat(), 461 "messages": [] 462 } 463 464 def _save_file(self, session_id: str, data: Dict): 465 """Write session data to disk. 466 467 This internal method handles writing session data to disk. 468 It ensures proper error handling and atomic writes. 469 470 Args: 471 session_id (str): ID of the session to save. 472 data (Dict): Session data to write. 473 474 Raises: 475 OSError: If the file cannot be written. 476 """ 477 file = self.path / f"{session_id}.json" 478 temp_file = file.with_suffix(".tmp") 479 480 try: 481 # Write to temporary file first 482 with open(temp_file, "w") as f: 483 json.dump(data, f, indent=2) 484 485 # Atomic rename 486 temp_file.replace(file) 487 except Exception as e: 488 if temp_file.exists(): 489 temp_file.unlink() 490 raise OSError(f"Failed to save session '{session_id}': {e}")
20 def __init__(self, directory: str = None): 21 """ 22 Initialize the session manager and ensure the session directory exists. 23 24 Args: 25 directory (str): Filesystem path for session storage. Must not be None or empty. 26 27 Raises: 28 ValueError: If directory is None or not a string. 29 OSError: If the directory cannot be created or accessed. 30 """ 31 if not directory: 32 raise ValueError("Session directory must be a valid non-empty string path.") 33 34 try: 35 self.path = Path(os.path.expanduser(directory)) 36 self.path.mkdir(parents=True, exist_ok=True) 37 except Exception as e: 38 raise OSError(f"Failed to initialize session directory '{directory}': {e}")
Initialize the session manager and ensure the session directory exists.
Args: directory (str): Filesystem path for session storage. Must not be None or empty.
Raises: ValueError: If directory is None or not a string. OSError: If the directory cannot be created or accessed.
44 def list(self) -> List[Dict]: 45 """ 46 Return metadata for all sessions in the directory. 47 48 Returns: 49 List[Dict]: Sorted list of session metadata dictionaries. 50 """ 51 out = [] 52 for file in self.path.glob("*.json"): 53 try: 54 with open(file, "r") as f: 55 d = json.load(f) 56 out.append({ 57 "id": d.get("id"), 58 "name": d.get("name"), 59 "created": d.get("created"), 60 "tags": d.get("tags", []), 61 "summary": d.get("summary") 62 }) 63 except Exception: 64 continue 65 return sorted(out, key=lambda x: x["created"], reverse=True)
Return metadata for all sessions in the directory.
Returns: List[Dict]: Sorted list of session metadata dictionaries.
67 def create(self, name: str, tags: Optional[List[str]] = None) -> str: 68 """ 69 Create and persist a new session. 70 71 Args: 72 name (str): Name of the new session. 73 tags (List[str], optional): Optional list of tags. 74 75 Returns: 76 str: Unique session ID of the new session. 77 """ 78 sid = str(uuid.uuid4()) 79 session = { 80 "id": sid, 81 "name": name, 82 "created": datetime.now(timezone.utc).isoformat(), 83 "parent": None, 84 "branch_point": None, 85 "tags": tags or [], 86 "summary": None, 87 "messages": [] 88 } 89 self._save_file(sid, session) 90 return sid
Create and persist a new session.
Args: name (str): Name of the new session. tags (List[str], optional): Optional list of tags.
Returns: str: Unique session ID of the new session.
92 def load(self, session_id: str) -> List[Dict]: 93 """ 94 Return OpenAI-compatible message list from a session. 95 96 Filters out internal keys and leaves only standard API-compatible fields. 97 98 Args: 99 session_id (str): ID of the session to load. 100 101 Returns: 102 List[Dict]: List of clean message dictionaries. 103 """ 104 session = self._read_file(session_id) 105 return [ 106 {k: v for k, v in m.items() if k in { 107 "role", "content", "tool_calls", "name", "function_call", "tool_call_id" 108 }} for m in session.get("messages", []) 109 ]
Return OpenAI-compatible message list from a session.
Filters out internal keys and leaves only standard API-compatible fields.
Args: session_id (str): ID of the session to load.
Returns: List[Dict]: List of clean message dictionaries.
111 def load_full(self, session_id: str) -> Dict: 112 """ 113 Return the complete session file as-is. 114 115 Args: 116 session_id (str): ID of the session. 117 118 Returns: 119 Dict: Entire raw session data. 120 """ 121 return self._read_file(session_id)
Return the complete session file as-is.
Args: session_id (str): ID of the session.
Returns: Dict: Entire raw session data.
123 def delete(self, session_id: str): 124 """ 125 Delete a session file from disk. 126 127 Args: 128 session_id (str): ID of the session to delete. 129 """ 130 file = self.path / f"{session_id}.json" 131 if file.exists(): 132 file.unlink()
Delete a session file from disk.
Args: session_id (str): ID of the session to delete.
134 def update(self, session_id: str, key: str, value: Any): 135 """ 136 Update a top-level key in a session file. 137 138 Args: 139 session_id (str): Session ID. 140 key (str): Field to update. 141 value (Any): New value for the field. 142 """ 143 session = self._read_file(session_id) 144 session[key] = value 145 self._save_file(session_id, session)
Update a top-level key in a session file.
Args: session_id (str): Session ID. key (str): Field to update. value (Any): New value for the field.
151 def msg_insert(self, session_id: str, message: Dict) -> str: 152 """ 153 Insert a new message into a session. 154 155 Args: 156 session_id (str): Session ID. 157 message (Dict): Message dictionary to insert. 158 159 Returns: 160 str: ID of the inserted message. 161 """ 162 session = self._read_file(session_id) 163 entry = { 164 "id": str(uuid.uuid4()), 165 "timestamp": datetime.now(timezone.utc).isoformat(), 166 **message 167 } 168 session["messages"].append(entry) 169 self._save_file(session_id, session) 170 return entry["id"]
Insert a new message into a session.
Args: session_id (str): Session ID. message (Dict): Message dictionary to insert.
Returns: str: ID of the inserted message.
172 def msg_get(self, session_id: str, message_id: str) -> Optional[Dict]: 173 """ 174 Retrieve a specific message from a session. 175 176 Args: 177 session_id (str): Session ID. 178 message_id (str): ID of the message to retrieve. 179 180 Returns: 181 Optional[Dict]: The message if found, else None. 182 """ 183 session = self._read_file(session_id) 184 for msg in session.get("messages", []): 185 if msg.get("id") == message_id: 186 return msg 187 return None
Retrieve a specific message from a session.
Args: session_id (str): Session ID. message_id (str): ID of the message to retrieve.
Returns: Optional[Dict]: The message if found, else None.
189 def msg_index(self, session_id: str, message_id: str) -> Optional[int]: 190 """ 191 Get the index of a message within a session. 192 193 Args: 194 session_id (str): Session ID. 195 message_id (str): Message ID. 196 197 Returns: 198 Optional[int]: Index if found, else None. 199 """ 200 session = self._read_file(session_id) 201 for i, msg in enumerate(session.get("messages", [])): 202 if msg.get("id") == message_id: 203 return i 204 return None
Get the index of a message within a session.
Args: session_id (str): Session ID. message_id (str): Message ID.
Returns: Optional[int]: Index if found, else None.
206 def msg_update(self, session_id: str, message_id: str, new_content: str) -> bool: 207 """ 208 Update the content of a specific message. 209 210 Args: 211 session_id (str): Session ID. 212 message_id (str): Message ID. 213 new_content (str): New content for the message. 214 215 Returns: 216 bool: True if update succeeded, False otherwise. 217 """ 218 session = self._read_file(session_id) 219 for m in session["messages"]: 220 if m.get("id") == message_id: 221 m["content"] = new_content 222 self._save_file(session_id, session) 223 return True 224 return False
Update the content of a specific message.
Args: session_id (str): Session ID. message_id (str): Message ID. new_content (str): New content for the message.
Returns: bool: True if update succeeded, False otherwise.
226 def msg_delete(self, session_id: str, message_id: str) -> bool: 227 """ 228 Delete a message from a session. 229 230 Args: 231 session_id (str): Session ID. 232 message_id (str): Message ID. 233 234 Returns: 235 bool: True if deletion occurred, False otherwise. 236 """ 237 session = self._read_file(session_id) 238 before = len(session["messages"]) 239 session["messages"] = [m for m in session["messages"] if m.get("id") != message_id] 240 self._save_file(session_id, session) 241 return len(session["messages"]) < before
Delete a message from a session.
Args: session_id (str): Session ID. message_id (str): Message ID.
Returns: bool: True if deletion occurred, False otherwise.
247 def branch(self, from_id: str, message_id: str, new_name: str) -> str: 248 """Create a new session by branching from a specific message. 249 250 This method creates a new session that branches from an existing one at a specific 251 message point. The new session inherits all messages up to and including the 252 specified message, then starts fresh from there. 253 254 Args: 255 from_id (str): ID of the source session to branch from. 256 message_id (str): ID of the message to branch at. 257 new_name (str): Name for the new branched session. 258 259 Returns: 260 str: ID of the newly created branched session. 261 262 Raises: 263 ValueError: If the source session or message ID is not found. 264 """ 265 # Get source session 266 source = self._read_file(from_id) 267 if not source: 268 raise ValueError(f"Source session '{from_id}' not found") 269 270 # Find the branch point 271 branch_index = self.msg_index(from_id, message_id) 272 if branch_index is None: 273 raise ValueError(f"Message '{message_id}' not found in session '{from_id}'") 274 275 # Create new session 276 new_id = self.create(new_name, source.get("tags", [])) 277 new_session = self._read_file(new_id) 278 279 # Copy messages up to branch point 280 new_session["messages"] = source["messages"][:branch_index + 1] 281 new_session["parent"] = from_id 282 new_session["branch_point"] = message_id 283 284 # Save and return 285 self._save_file(new_id, new_session) 286 return new_id
Create a new session by branching from a specific message.
This method creates a new session that branches from an existing one at a specific message point. The new session inherits all messages up to and including the specified message, then starts fresh from there.
Args: from_id (str): ID of the source session to branch from. message_id (str): ID of the message to branch at. new_name (str): Name for the new branched session.
Returns: str: ID of the newly created branched session.
Raises: ValueError: If the source session or message ID is not found.
288 def summarize(self, interactor, session_id: str) -> str: 289 """Generate a summary of the session using the provided interactor. 290 291 This method uses the AI interactor to analyze the session content and generate 292 a concise summary. The summary is stored in the session metadata and returned. 293 294 Args: 295 interactor: An AI interactor instance capable of generating summaries. 296 session_id (str): ID of the session to summarize. 297 298 Returns: 299 str: The generated summary text. 300 301 Note: 302 The summary is automatically stored in the session metadata and can be 303 retrieved later using load_full(). 304 """ 305 session = self._read_file(session_id) 306 if not session: 307 return "" 308 309 # Get clean message list 310 messages = self.load(session_id) 311 if not messages: 312 return "" 313 314 # Generate summary 315 summary = interactor.interact( 316 "Summarize this conversation in 2-3 sentences:", 317 tools=False, 318 stream=False, 319 markdown=False 320 ) 321 322 # Store and return 323 session["summary"] = summary 324 self._save_file(session_id, session) 325 return summary
Generate a summary of the session using the provided interactor.
This method uses the AI interactor to analyze the session content and generate a concise summary. The summary is stored in the session metadata and returned.
Args: interactor: An AI interactor instance capable of generating summaries. session_id (str): ID of the session to summarize.
Returns: str: The generated summary text.
Note: The summary is automatically stored in the session metadata and can be retrieved later using load_full().
331 def search(self, query: str, session_id: Optional[str] = None) -> List[Dict]: 332 """Search for messages containing the query text within a session or all sessions. 333 334 This method performs a case-insensitive text search across message content. 335 If a session_id is provided, only searches within that session. Otherwise, 336 searches across all sessions. 337 338 Args: 339 query (str): Text to search for. 340 session_id (Optional[str]): Optional session ID to limit search scope. 341 342 Returns: 343 List[Dict]: List of matching messages with their session context. 344 Each dict contains: 345 - session_id: ID of the containing session 346 - message: The matching message 347 - context: Surrounding messages for context 348 """ 349 results = [] 350 query = query.lower() 351 352 # Determine search scope 353 if session_id: 354 sessions = [(session_id, self._read_file(session_id))] 355 else: 356 sessions = [(f.stem, self._read_file(f.stem)) for f in self.path.glob("*.json")] 357 358 # Search each session 359 for sid, session in sessions: 360 if not session: 361 continue 362 363 messages = session.get("messages", []) 364 for i, msg in enumerate(messages): 365 content = str(msg.get("content", "")).lower() 366 if query in content: 367 # Get context (2 messages before and after) 368 start = max(0, i - 2) 369 end = min(len(messages), i + 3) 370 context = messages[start:end] 371 372 results.append({ 373 "session_id": sid, 374 "message": msg, 375 "context": context 376 }) 377 378 return results
Search for messages containing the query text within a session or all sessions.
This method performs a case-insensitive text search across message content. If a session_id is provided, only searches within that session. Otherwise, searches across all sessions.
Args: query (str): Text to search for. session_id (Optional[str]): Optional session ID to limit search scope.
Returns: List[Dict]: List of matching messages with their session context. Each dict contains: - session_id: ID of the containing session - message: The matching message - context: Surrounding messages for context
380 def search_meta(self, query: str) -> List[Dict]: 381 """Search session metadata (name, tags, summary) for matching sessions. 382 383 This method performs a case-insensitive search across session metadata fields 384 including name, tags, and summary. It returns matching sessions with their 385 full metadata. 386 387 Args: 388 query (str): Text to search for in metadata. 389 390 Returns: 391 List[Dict]: List of matching session metadata dictionaries. 392 Each dict contains: 393 - id: Session ID 394 - name: Session name 395 - created: Creation timestamp 396 - tags: List of tags 397 - summary: Session summary if available 398 """ 399 results = [] 400 query = query.lower() 401 402 for file in self.path.glob("*.json"): 403 try: 404 with open(file, "r") as f: 405 session = json.load(f) 406 407 # Check metadata fields 408 name = str(session.get("name", "")).lower() 409 tags = [str(t).lower() for t in session.get("tags", [])] 410 summary = str(session.get("summary", "")).lower() 411 412 if (query in name or 413 any(query in tag for tag in tags) or 414 query in summary): 415 results.append({ 416 "id": session.get("id"), 417 "name": session.get("name"), 418 "created": session.get("created"), 419 "tags": session.get("tags", []), 420 "summary": session.get("summary") 421 }) 422 except Exception: 423 continue 424 425 return sorted(results, key=lambda x: x["created"], reverse=True)
Search session metadata (name, tags, summary) for matching sessions.
This method performs a case-insensitive search across session metadata fields including name, tags, and summary. It returns matching sessions with their full metadata.
Args: query (str): Text to search for in metadata.
Returns: List[Dict]: List of matching session metadata dictionaries. Each dict contains: - id: Session ID - name: Session name - created: Creation timestamp - tags: List of tags - summary: Session summary if available