interactor
61class Interactor: 62 def __init__( 63 self, 64 base_url: Optional[str] = None, 65 api_key: Optional[str] = None, 66 model: str = "openai:gpt-4o-mini", 67 fallback_model = "ollama:mistral-nemo:latest", 68 tools: Optional[bool] = True, 69 stream: bool = True, 70 context_length: int = 128000, 71 max_retries: int = 3, 72 retry_delay: float = 1.0, 73 log_path: Optional[str] = None, 74 session_enabled: bool = False, 75 session_id: Optional[str] = None, 76 session_path: Optional[str] = None 77 ): 78 """Initialize the universal AI interaction client. 79 80 Args: 81 base_url: Optional base URL for the API. If None, uses the provider's default URL. 82 api_key: Optional API key. If None, attempts to use environment variables based on provider. 83 model: Model identifier in format "provider:model_name" (e.g., "openai:gpt-4o-mini"). 84 tools: Enable (True) or disable (False) tool calling; None for auto-detection based on model support. 85 stream: Enable (True) or disable (False) streaming responses. 86 context_length: Maximum number of tokens to maintain in conversation history. 87 max_retries: Maximum number of retries for failed API calls. 88 retry_delay: Initial delay (in seconds) for exponential backoff retries. 89 session_enabled: Enable persistent session support. 90 session_id: Optional session ID to load messages from. 91 92 Raises: 93 ValueError: If provider is not supported or API key is missing for non-Ollama providers. 94 """ 95 self.system = "You are a helpful Assistant." 96 self.logger = logging.getLogger(f"InteractorLogger_{id(self)}") 97 self.logger.setLevel(logging.DEBUG) 98 self.providers = { 99 "openai": { 100 "sdk": "openai", 101 "base_url": "https://api.openai.com/v1", 102 "api_key": api_key or os.getenv("OPENAI_API_KEY") or None 103 }, 104 "ollama": { 105 "sdk": "openai", 106 "base_url": "http://localhost:11434/v1", 107 "api_key": api_key or "ollama" 108 }, 109 "nvidia": { 110 "sdk": "openai", 111 "base_url": "https://integrate.api.nvidia.com/v1", 112 "api_key": api_key or os.getenv("NVIDIA_API_KEY") or None 113 }, 114 "google": { 115 "sdk": "openai", 116 "base_url": "https://generativelanguage.googleapis.com/v1beta/openai", 117 "api_key": api_key or os.getenv("GEMINI_API_KEY") or None 118 }, 119 "anthropic": { 120 "sdk": "anthropic", 121 "base_url": "https://api.anthropic.com/v1", 122 "api_key": api_key or os.getenv("ANTHROPIC_API_KEY") or None 123 }, 124 "mistral": { 125 "sdk": "openai", 126 "base_url": "https://api.mistral.ai/v1", 127 "api_key": api_key or os.getenv("MISTRAL_API_KEY") or None 128 }, 129 "deepseek": { 130 "sdk": "openai", 131 "base_url": "https://api.deepseek.com", 132 "api_key": api_key or os.getenv("DEEPSEEK_API_KEY") or None 133 }, 134 "grok": { 135 "sdk": "grok", 136 "base_url": "https://api.x.ai/v1", 137 "api_key": api_key or os.getenv("GROK_API_KEY") or None 138 } 139 } 140 141 142 # Console log handler (always enabled at WARNING+) 143 if not self.logger.handlers: 144 console_handler = logging.StreamHandler(sys.stdout) 145 console_handler.setLevel(logging.WARNING) 146 console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) 147 self.logger.addHandler(console_handler) 148 149 self._log_enabled = False 150 if log_path: 151 file_handler = logging.FileHandler(log_path) 152 file_handler.setLevel(logging.DEBUG) 153 file_handler.setFormatter(logging.Formatter( 154 "%(asctime)s - %(levelname)s - %(message)s", 155 datefmt="%Y-%m-%d %H:%M:%S" 156 )) 157 self.logger.addHandler(file_handler) 158 self._log_enabled = True 159 160 161 self.token_estimate = 0 162 self.last_token_estimate = 0 163 self.stream = stream 164 self.tools = [] 165 self.session_history = [] 166 self.history = [] 167 self.context_length = context_length 168 self.encoding = None 169 self.max_retries = max_retries 170 self.retry_delay = retry_delay 171 self.reveal_tool = [] 172 self.fallback_model = fallback_model 173 self.sdk = None 174 175 # Session support 176 self.session_enabled = session_enabled 177 self.session_id = session_id 178 self._last_session_id = session_id 179 self.session = Session(directory=session_path) if session_enabled else None 180 181 182 if model is None: 183 model = "openai:gpt-4o-mini" 184 185 # Initialize model + encoding 186 self._setup_client(model, base_url, api_key) 187 self.tools_enabled = self.tools_supported if tools is None else tools and self.tools_supported 188 self._setup_encoding() 189 self.messages_add(role="system", content=self.system) 190 191 def _log(self, message: str, level: str = "info"): 192 if self._log_enabled: 193 getattr(self.logger, level)(message) 194 195 def _setup_client( 196 self, 197 model: Optional[str] = None, 198 base_url: Optional[str] = None, 199 api_key: Optional[str] = None 200 ): 201 """Initialize or reconfigure the Interactor for the given model and SDK. 202 203 Ensures idempotent setup, assigns SDK-specific clients and tool handling logic, 204 and normalizes history to match the provider-specific message schema. 205 """ 206 if not model: 207 raise ValueError("Model must be specified as 'provider:model_name'") 208 209 provider, model_name = model.split(":", 1) 210 211 if not hasattr(self, "session_history"): 212 self.session_history = [] 213 214 # Skip setup if nothing has changed (client may not yet exist on first call) 215 if ( 216 hasattr(self, "client") 217 and self.client 218 and self.provider == provider 219 and self.model == model_name 220 and self.base_url == (base_url or self.base_url) 221 ): 222 return 223 224 if provider not in self.providers: 225 raise ValueError(f"Unsupported provider: {provider}. Supported: {list(self.providers.keys())}") 226 227 # Load provider configuration 228 provider_config = self.providers[provider] 229 self.sdk = provider_config.get("sdk", "openai") 230 self.provider = provider 231 self.model = model_name 232 self.base_url = base_url or provider_config["base_url"] 233 effective_api_key = api_key or provider_config["api_key"] 234 235 if not effective_api_key and provider != "ollama": 236 raise ValueError(f"API key not provided and not found in environment for {provider.upper()}_API_KEY") 237 238 # SDK-specific configuration 239 if self.sdk == "openai": 240 self.client = openai.OpenAI(base_url=self.base_url, api_key=effective_api_key) 241 self.async_client = openai.AsyncOpenAI(base_url=self.base_url, api_key=effective_api_key) 242 self.sdk_runner = self._openai_runner 243 self.tool_key = "tool_call_id" 244 245 elif self.sdk == "anthropic": 246 self.client = anthropic.Anthropic(api_key=effective_api_key) 247 self.async_client = anthropic.AsyncAnthropic(api_key=effective_api_key) 248 self.sdk_runner = self._anthropic_runner 249 self.tool_key = "tool_use_id" 250 251 else: 252 raise ValueError(f"Unsupported SDK type: {self.sdk}") 253 254 # Determine tool support 255 self.tools_supported = self._check_tool_support() 256 if not self.tools_supported: 257 self.logger.warning(f"Tool calling not supported for {provider}:{model_name}") 258 259 # Normalize session history to match SDK after any provider/model change 260 self._normalizer(force=True) 261 262 self._log(f"[MODEL] Switched to {provider}:{model_name}") 263 264 def _check_tool_support(self) -> bool: 265 """Determine if the current model supports tool calling. 266 267 Returns: 268 bool: True if tools are supported for the active provider/model, False otherwise. 269 """ 270 try: 271 if self.sdk == "openai": 272 response = self.client.chat.completions.create( 273 model=self.model, 274 messages=[{"role": "user", "content": "Test tool support."}], 275 stream=False, 276 tools=[{ 277 "type": "function", 278 "function": { 279 "name": "test_tool", 280 "description": "Check tool support", 281 "parameters": { 282 "type": "object", 283 "properties": { 284 "query": {"type": "string"} 285 }, 286 "required": ["query"] 287 } 288 } 289 }], 290 tool_choice="auto" 291 ) 292 message = response.choices[0].message 293 return bool(message.tool_calls and len(message.tool_calls) > 0) 294 295 elif self.sdk == "anthropic": 296 # For Claude models, we pre-define support based on model ID 297 # Known tool-supporting Claude models 298 claude_tool_models = ["claude-3-opus", "claude-3-sonnet", "claude-3-haiku", 299 "claude-3.5-sonnet", "claude-3.7-sonnet"] 300 301 # Check if the current model supports tools 302 for supported_model in claude_tool_models: 303 if supported_model in self.model.lower(): 304 self._log(f"[TOOLS] Anthropic model {self.model} is known to support tools") 305 return True 306 307 # If not explicitly supported, try to test 308 try: 309 _ = self.client.messages.create( 310 model=self.model, 311 messages=[{"role": "user", "content": "What's the weather?"}], 312 tools=[{ 313 "name": "test_tool", 314 "description": "Check tool support", 315 "input_schema": { 316 "type": "object", 317 "properties": { 318 "query": {"type": "string"} 319 }, 320 "required": ["query"] 321 } 322 }], 323 max_tokens=10 324 ) 325 return True 326 except anthropic.BadRequestError as e: 327 error_msg = str(e).lower() 328 if "tool" in error_msg and "not supported" in error_msg: 329 self._log(f"[TOOLS] Anthropic model {self.model} does not support tools: {e}") 330 return False 331 if "not a supported tool field" in error_msg: 332 self._log(f"[TOOLS] Anthropic API rejected tool format: {e}") 333 return False 334 raise 335 except Exception as e: 336 self._log(f"[TOOLS] Unexpected error testing tool support: {e}", level="error") 337 return False 338 339 else: 340 self.logger.warning(f"Tool support check not implemented for SDK '{self.sdk}'") 341 return False 342 343 except Exception as e: 344 self.logger.error(f"Tool support check failed for {self.provider}:{self.model} — {e}") 345 return False 346 347 def add_function( 348 self, 349 external_callable: Callable, 350 name: Optional[str] = None, 351 description: Optional[str] = None, 352 override: bool = False, 353 disabled: bool = False, 354 schema_extensions: Optional[Dict[str, Any]] = None 355 ): 356 """ 357 Register a function for LLM tool calling with full type hints and metadata. 358 359 Args: 360 external_callable (Callable): The function to register. 361 name (Optional[str]): Optional custom name. Defaults to function's __name__. 362 description (Optional[str]): Optional custom description. Defaults to first line of docstring. 363 override (bool): If True, replaces an existing tool with the same name. 364 disabled (bool): If True, registers the function in a disabled state. 365 schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to 366 schema extensions that override or add to the auto-generated schema. 367 368 Raises: 369 ValueError: If the callable is invalid or duplicate name found without override. 370 371 Example: 372 interactor.add_function( 373 my_tool, 374 override=True, 375 disabled=False, 376 schema_extensions={ 377 "param1": {"minimum": 0, "maximum": 100}, 378 "param2": {"format": "email"} 379 } 380 ) 381 """ 382 def _python_type_to_schema(ptype: Any) -> dict: 383 """Convert a Python type annotation to OpenAI-compatible JSON Schema.""" 384 # Handle None case 385 if ptype is None: 386 return {"type": "null"} 387 388 # Get the origin and arguments of the type 389 origin = get_origin(ptype) 390 args = get_args(ptype) 391 392 # Handle Union types (including Optional) 393 if origin is Union: 394 # Check for Optional (Union with None) 395 none_type = type(None) 396 if none_type in args: 397 non_none = [a for a in args if a is not none_type] 398 if len(non_none) == 1: 399 inner = _python_type_to_schema(non_none[0]) 400 inner_copy = inner.copy() 401 inner_copy["nullable"] = True 402 return inner_copy 403 # Multiple types excluding None 404 types = [_python_type_to_schema(a) for a in non_none] 405 return {"anyOf": types, "nullable": True} 406 # Regular Union without None 407 return {"anyOf": [_python_type_to_schema(a) for a in args]} 408 409 # Handle List and similar container types 410 if origin in (list, List): 411 item_type = args[0] if args else Any 412 if item_type is Any: 413 return {"type": "array"} 414 return {"type": "array", "items": _python_type_to_schema(item_type)} 415 416 # Handle Dict types with typing info 417 if origin in (dict, Dict): 418 if not args or len(args) != 2: 419 return {"type": "object"} 420 421 key_type, val_type = args 422 # We can only really use val_type in JSON Schema 423 if val_type is not Any and val_type is not object: 424 return { 425 "type": "object", 426 "additionalProperties": _python_type_to_schema(val_type) 427 } 428 return {"type": "object"} 429 430 # Handle Literal types for enums 431 if origin is Literal: 432 values = args 433 # Try to determine type from values 434 if all(isinstance(v, str) for v in values): 435 return {"type": "string", "enum": list(values)} 436 elif all(isinstance(v, bool) for v in values): 437 return {"type": "boolean", "enum": list(values)} 438 elif all(isinstance(v, (int, float)) for v in values): 439 return {"type": "number", "enum": list(values)} 440 else: 441 # Mixed types, use anyOf 442 return {"anyOf": [{"type": _get_json_type(v), "enum": [v]} for v in values]} 443 444 # Handle basic types 445 if ptype is str: 446 return {"type": "string"} 447 if ptype is int: 448 return {"type": "integer"} 449 if ptype is float: 450 return {"type": "number"} 451 if ptype is bool: 452 return {"type": "boolean"} 453 454 # Handle common datetime types 455 if ptype is datetime: 456 return {"type": "string", "format": "date-time"} 457 if ptype is date: 458 return {"type": "string", "format": "date"} 459 460 # Handle UUID 461 if ptype is uuid.UUID: 462 return {"type": "string", "format": "uuid"} 463 464 # Default to object for any other types 465 return {"type": "object"} 466 467 def _get_json_type(value): 468 """Get the JSON Schema type name for a Python value.""" 469 if isinstance(value, str): 470 return "string" 471 elif isinstance(value, bool): 472 return "boolean" 473 elif isinstance(value, int) or isinstance(value, float): 474 return "number" 475 elif isinstance(value, list): 476 return "array" 477 elif isinstance(value, dict): 478 return "object" 479 else: 480 return "object" # Default 481 482 def _parse_param_docs(docstring: str) -> dict: 483 """Extract parameter descriptions from a docstring.""" 484 if not docstring: 485 return {} 486 487 lines = docstring.splitlines() 488 param_docs = {} 489 current_param = None 490 in_params = False 491 492 # Regular expressions for finding parameter sections and param lines 493 param_section_re = re.compile(r"^(Args|Parameters):\s*$") 494 param_line_re = re.compile(r"^\s{4}(\w+)\s*(?:\([^\)]*\))?:\s*(.*)") 495 496 for line in lines: 497 # Check if we're entering the parameters section 498 if param_section_re.match(line.strip()): 499 in_params = True 500 continue 501 502 if in_params: 503 # Skip empty lines 504 if not line.strip(): 505 continue 506 507 # Check for a parameter definition line 508 match = param_line_re.match(line) 509 if match: 510 current_param = match.group(1) 511 param_docs[current_param] = match.group(2).strip() 512 # Check for continuation of a parameter description 513 elif current_param and line.startswith(" " * 8): 514 param_docs[current_param] += " " + line.strip() 515 # If we see a line that doesn't match our patterns, we're out of the params section 516 else: 517 current_param = None 518 519 return param_docs 520 521 # Start of main function logic 522 523 # Skip if tools are disabled 524 if not self.tools_enabled: 525 return 526 527 # Validate input callable 528 if not external_callable: 529 raise ValueError("A valid external callable must be provided.") 530 531 # Set function name, either from parameter or from callable's __name__ 532 function_name = name or external_callable.__name__ 533 534 # Try to get docstring and extract description 535 try: 536 docstring = inspect.getdoc(external_callable) 537 description = description or (docstring.split("\n")[0].strip() if docstring else "No description provided.") 538 except Exception as e: 539 self._log(f"[TOOL] Warning: Could not extract docstring from {function_name}: {e}", level="warning") 540 docstring = "" 541 description = description or "No description provided." 542 543 # Extract parameter documentation from docstring 544 param_docs = _parse_param_docs(docstring) 545 546 # Handle conflicts with existing functions 547 if override: 548 self.delete_function(function_name) 549 elif any(t["function"]["name"] == function_name for t in self.tools): 550 raise ValueError(f"Function '{function_name}' is already registered. Use override=True to replace.") 551 552 # Try to get function signature for parameter info 553 try: 554 signature = inspect.signature(external_callable) 555 except (ValueError, TypeError) as e: 556 raise ValueError(f"Cannot inspect callable '{function_name}': {e}") 557 558 # Process parameters to build schema 559 properties = {} 560 required = [] 561 562 for param_name, param in signature.parameters.items(): 563 # Skip self, cls parameters for instance/class methods 564 if param_name in ("self", "cls") and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: 565 continue 566 567 # Get parameter annotation, defaulting to Any 568 annotation = param.annotation if param.annotation != inspect.Parameter.empty else Any 569 570 try: 571 # Convert Python type to JSON Schema 572 schema = _python_type_to_schema(annotation) 573 574 # Add description from docstring or create a default one 575 schema["description"] = param_docs.get(param_name, f"{param_name} parameter") 576 577 # Add to properties 578 properties[param_name] = schema 579 580 # If no default, add to required list 581 if param.default == inspect.Parameter.empty: 582 required.append(param_name) 583 584 except Exception as e: 585 self._log(f"[TOOL] Error processing parameter {param_name} for {function_name}: {e}", level="error") 586 # Add a basic object schema as fallback 587 properties[param_name] = { 588 "type": "object", 589 "description": f"{param_name} parameter (type conversion failed)" 590 } 591 592 # Apply schema extensions if provided 593 if schema_extensions: 594 for param_name, extensions in schema_extensions.items(): 595 if param_name in properties: 596 properties[param_name].update(extensions) 597 598 # Build the final tool specification 599 tool_spec = { 600 "type": "function", 601 "function": { 602 "name": function_name, 603 "description": description, 604 "parameters": { 605 "type": "object", 606 "properties": properties, 607 "required": required 608 } 609 } 610 } 611 612 # Set disabled flag if requested 613 if disabled: 614 tool_spec["function"]["disabled"] = True 615 616 # Add to tools list 617 self.tools.append(tool_spec) 618 619 # Make the function available as an attribute on the instance 620 setattr(self, function_name, external_callable) 621 622 # Log the registration 623 self._log(f"[TOOL] Registered function '{function_name}' with {len(properties)} parameters", level="info") 624 625 return function_name # Return the name for reference 626 627 628 def disable_function(self, name: str) -> bool: 629 """ 630 Disable a registered tool function by name. 631 632 This marks the function as inactive for tool calling without removing it from the internal registry. 633 The function remains visible in the tool listing but is skipped during tool selection by the LLM. 634 635 Args: 636 name (str): The name of the function to disable. 637 638 Returns: 639 bool: True if the function was found and disabled, False otherwise. 640 641 Example: 642 interactor.disable_function("extract_text") 643 """ 644 for tool in self.tools: 645 if tool["function"]["name"] == name: 646 tool["function"]["disabled"] = True 647 return True 648 return False 649 650 651 def enable_function(self, name: str) -> bool: 652 """ 653 Re-enable a previously disabled tool function by name. 654 655 This removes the 'disabled' flag from a tool function, making it available again for LLM use. 656 657 Args: 658 name (str): The name of the function to enable. 659 660 Returns: 661 bool: True if the function was found and enabled, False otherwise. 662 663 Example: 664 interactor.enable_function("extract_text") 665 """ 666 for tool in self.tools: 667 if tool["function"]["name"] == name: 668 tool["function"].pop("disabled", None) 669 return True 670 return False 671 672 673 def delete_function(self, name: str) -> bool: 674 """ 675 Permanently remove a registered tool function from the Interactor. 676 677 This deletes both the tool metadata and the callable attribute, making it fully inaccessible 678 from the active session. Useful for dynamically trimming the toolset. 679 680 Args: 681 name (str): The name of the function to delete. 682 683 Returns: 684 bool: True if the function was found and removed, False otherwise. 685 686 Example: 687 interactor.delete_function("extract_text") 688 """ 689 before = len(self.tools) 690 self.tools = [tool for tool in self.tools if tool["function"]["name"] != name] 691 if hasattr(self, name): 692 delattr(self, name) 693 return len(self.tools) < before 694 695 696 def list_functions(self) -> List[Dict[str, Any]]: 697 """Get the list of registered functions for tool calling. 698 699 Returns: 700 List[Dict[str, Any]]: List of registered functions. 701 """ 702 return self.tools 703 704 def list_models( 705 self, 706 providers: Optional[Union[str, List[str]]] = None, 707 filter: Optional[str] = None 708 ) -> List[str]: 709 """Retrieve available models from configured providers. 710 711 Args: 712 providers: Provider name or list of provider names. If None, all are queried. 713 filter: Optional regex to filter model names. 714 715 Returns: 716 List[str]: Sorted list of "provider:model_id" strings. 717 """ 718 models = [] 719 720 if providers is None: 721 providers_to_list = self.providers 722 elif isinstance(providers, str): 723 providers_to_list = {providers: self.providers.get(providers)} 724 elif isinstance(providers, list): 725 providers_to_list = {p: self.providers.get(p) for p in providers} 726 else: 727 return [] 728 729 invalid_providers = [p for p in providers_to_list if p not in self.providers or self.providers[p] is None] 730 if invalid_providers: 731 self.logger.error(f"Invalid providers: {invalid_providers}") 732 return [] 733 734 regex_pattern = None 735 if filter: 736 try: 737 regex_pattern = re.compile(filter, re.IGNORECASE) 738 except re.error as e: 739 self.logger.error(f"Invalid regex pattern: {e}") 740 return [] 741 742 for provider_name, config in providers_to_list.items(): 743 sdk = config.get("sdk", "openai") 744 base_url = config.get("base_url") 745 api_key = config.get("api_key") 746 747 try: 748 if sdk == "openai": 749 client = openai.OpenAI(api_key=api_key, base_url=base_url) 750 response = client.models.list() 751 for model in response.data: 752 model_id = f"{provider_name}:{model.id}" 753 if not regex_pattern or regex_pattern.search(model_id): 754 models.append(model_id) 755 756 elif sdk == "anthropic": 757 client = Anthropic(api_key=api_key) 758 response = client.models.list() 759 for model in response: 760 model_id = f"{provider_name}:{model.id}" 761 if not regex_pattern or regex_pattern.search(model_id): 762 models.append(model_id) 763 else: 764 self.logger.warning(f"SDK '{sdk}' for provider '{provider_name}' is not supported by list_models()") 765 766 except Exception as e: 767 self.logger.error(f"Failed to list models for {provider_name}: {e}") 768 769 return sorted(models, key=str.lower) 770 771 772 async def _retry_with_backoff(self, func: Callable, *args, **kwargs): 773 for attempt in range(self.max_retries + 1): 774 try: 775 return await func(*args, **kwargs) 776 777 except (RateLimitError, APIConnectionError, aiohttp.ClientError) as e: 778 if attempt == self.max_retries: 779 model_key = f"{self.provider}:{self.model}" 780 if self.fallback_model and model_key != self.fallback_model: 781 print(f"[yellow]Model '{model_key}' failed. Switching to fallback: {self.fallback_model}[/yellow]") 782 self._setup_client(self.fallback_model) 783 self._setup_encoding() 784 self._normalizer() 785 return await func(*args, **kwargs) # retry once with fallback model 786 else: 787 self.logger.error(f"All {self.max_retries} retries failed: {e}") 788 raise 789 790 delay = self.retry_delay * (2 ** attempt) 791 self.logger.warning(f"Retry {attempt + 1}/{self.max_retries} after {delay}s due to {e}") 792 self._log(f"[RETRY] Attempt {attempt + 1}/{self.max_retries} failed: {e}", level="warning") 793 await asyncio.sleep(delay) 794 795 except OpenAIError as e: 796 self.logger.error(f"OpenAI error: {e}") 797 raise 798 799 except Exception as e: 800 self.logger.error(f"Unexpected error: {e}") 801 raise 802 803 def interact( 804 self, 805 user_input: Optional[str], 806 quiet: bool = False, 807 tools: bool = True, 808 stream: bool = True, 809 markdown: bool = False, 810 model: Optional[str] = None, 811 output_callback: Optional[Callable[[str], None]] = None, 812 session_id: Optional[str] = None 813 ) -> Optional[str]: 814 """Main universal gateway for all LLM interaction.""" 815 if not user_input: 816 return None 817 818 # Setup model if specified 819 if model: 820 self._setup_client(model) 821 self._setup_encoding() 822 823 # Session handling 824 if self.session_enabled and session_id: 825 self.session_id = session_id 826 self.session_load(session_id) 827 828 # Add user message using messages_add 829 self.messages_add(role="user", content=user_input) 830 831 # Log token count estimate 832 token_count = self._count_tokens(self.history) 833 if not quiet: 834 print(f"[dim]Estimated tokens in context: {token_count} / {self.context_length}[/dim]") 835 836 # Make sure we have enough context space 837 if token_count > self.context_length: 838 if self._cycle_messages(): 839 if not quiet: 840 print("[red]Context window exceeded. Cannot proceed.[/red]") 841 return None 842 843 # Log user input 844 self._log(f"[USER] {user_input}") 845 846 # Handle the actual interaction with complete streaming for all responses 847 result = asyncio.run(self._interact_async_core( 848 user_input=user_input, 849 quiet=quiet, 850 tools=tools, 851 stream=stream, 852 markdown=markdown, 853 output_callback=output_callback 854 )) 855 856 # Log completion for this interaction 857 self._log(f"[INTERACTION] Completed with {len(self.history)} total messages") 858 859 return result 860 861 async def _interact_async_core( 862 self, 863 user_input: str, 864 quiet: bool = False, 865 tools: bool = True, 866 stream: bool = True, 867 markdown: bool = False, 868 output_callback: Optional[Callable] = None 869 ) -> str: 870 """Main SDK-agnostic async execution pipeline with tool call looping support.""" 871 # Prepare display handler 872 live = Live(console=console, refresh_per_second=100) if markdown and stream else None 873 if live: 874 live.start() 875 876 # Initialize variables for iteration tracking 877 full_content = "" 878 tool_enabled = self.tools_enabled and self.tools_supported and tools 879 max_iterations = 5 # Prevent infinite loops 880 iterations = 0 881 882 # Main interaction loop - continues until no more tool calls or max iterations reached 883 while iterations < max_iterations: 884 iterations += 1 885 886 try: 887 # Execute the appropriate SDK runner - history is already normalized 888 response_data = await self.sdk_runner( 889 model=self.model, 890 messages=self.history, 891 stream=stream, 892 markdown=markdown, 893 quiet=quiet if iterations == 1 else False, 894 live=live, 895 output_callback=output_callback 896 ) 897 898 # Extract response data 899 content = response_data.get("content", "") 900 tool_calls = response_data.get("tool_calls", []) 901 902 # Log the response data for debugging 903 self._log(f"[ITERATION {iterations}] Content: {len(content)} chars, Tool calls: {len(tool_calls)}") 904 905 # Add content to full response 906 if iterations == 1: 907 full_content = content 908 elif content: 909 if full_content and content: 910 full_content += f"\n{content}" 911 else: 912 full_content = content 913 914 # Add assistant message with or without tool calls 915 if tool_calls: 916 # Process each tool call 917 for call in tool_calls: 918 # Add assistant message with tool call 919 tool_info = { 920 "id": call["id"], 921 "name": call["function"]["name"], 922 "arguments": call["function"]["arguments"] 923 } 924 925 # Add the assistant message with tool call 926 self.messages_add( 927 role="assistant", 928 content=content if len(tool_calls) == 1 else "", 929 tool_info=tool_info 930 ) 931 932 # Execute the tool 933 call_name = call["function"]["name"] 934 call_args = call["function"]["arguments"] 935 call_id = call["id"] 936 937 result = await self._handle_tool_call_async( 938 call_name, call_args, call_id, 939 {"model": self.model, "messages": self.history, "stream": stream}, 940 markdown, live, False, output_callback 941 ) 942 943 # Add tool result message 944 tool_result_info = { 945 "id": call_id, 946 "result": result 947 } 948 949 self.messages_add( 950 role="tool", 951 content=result, 952 tool_info=tool_result_info 953 ) 954 else: 955 # Simple assistant response without tool calls 956 self.messages_add(role="assistant", content=content) 957 break # No more tools to process, we're done 958 959 # Reset live display if needed 960 if stream and live: 961 live.stop() 962 live = Live(console=console, refresh_per_second=100) 963 live.start() 964 965 except Exception as e: 966 self.logger.error(f"[{self.sdk.upper()} ERROR] {str(e)}") 967 self._log(f"[ERROR] Error in interaction loop: {str(e)}", level="error") 968 if live: 969 live.stop() 970 return f"Error: {str(e)}" 971 972 # Clean up display 973 if live: 974 live.stop() 975 976 return full_content or "No response." 977 978 async def _openai_runner( 979 self, 980 *, 981 model, 982 messages, 983 stream, 984 markdown=False, 985 quiet=False, 986 live=None, 987 output_callback=None 988 ): 989 """Handle OpenAI-specific API interactions and response processing.""" 990 # Log what we're sending for debugging 991 self._log(f"[OPENAI REQUEST] Sending request to {model} with {len(self.history)} messages", level="debug") 992 993 # Prepare API parameters - history is already normalized by _normalizer 994 params = { 995 "model": model, 996 "messages": self.history, 997 "stream": stream, 998 } 999 1000 # Add tools if enabled 1001 if self.tools_enabled and self.tools_supported: 1002 enabled_tools = self._get_enabled_tools() 1003 if enabled_tools: 1004 params["tools"] = enabled_tools 1005 params["tool_choice"] = "auto" 1006 1007 # Call API with retry handling 1008 try: 1009 response = await self._retry_with_backoff( 1010 self.async_client.chat.completions.create, 1011 **params 1012 ) 1013 except Exception as e: 1014 self.logger.error(f"[OPENAI ERROR] {str(e)}") 1015 raise 1016 1017 assistant_content = "" 1018 tool_calls_dict = {} 1019 1020 # Process streaming response 1021 if stream and hasattr(response, "__aiter__"): 1022 async for chunk in response: 1023 delta = getattr(chunk.choices[0], "delta", None) 1024 finish_reason = getattr(chunk.choices[0], "finish_reason", None) 1025 1026 # Handle content chunks 1027 if hasattr(delta, "content") and delta.content is not None: 1028 text = delta.content 1029 assistant_content += text 1030 if output_callback: 1031 output_callback(text) 1032 elif live: 1033 live.update(Markdown(assistant_content)) 1034 elif not markdown and not quiet: 1035 print(text, end="") 1036 1037 # Process tool calls 1038 if hasattr(delta, "tool_calls") and delta.tool_calls: 1039 for tool_call_delta in delta.tool_calls: 1040 index = tool_call_delta.index 1041 if index not in tool_calls_dict: 1042 tool_calls_dict[index] = { 1043 "id": tool_call_delta.id if hasattr(tool_call_delta, "id") else None, 1044 "function": {"name": "", "arguments": ""} 1045 } 1046 1047 function = getattr(tool_call_delta, "function", None) 1048 if function: 1049 name = getattr(function, "name", None) 1050 args = getattr(function, "arguments", "") 1051 if name: 1052 tool_calls_dict[index]["function"]["name"] = name 1053 if args: 1054 tool_calls_dict[index]["function"]["arguments"] += args 1055 if tool_call_delta.id and not tool_calls_dict[index]["id"]: 1056 tool_calls_dict[index]["id"] = tool_call_delta.id 1057 1058 # Make sure the ID is set regardless 1059 if hasattr(tool_call_delta, "id") and tool_call_delta.id and not tool_calls_dict[index]["id"]: 1060 tool_calls_dict[index]["id"] = tool_call_delta.id 1061 1062 if not output_callback and not markdown and not quiet: 1063 print() 1064 1065 # Process non-streaming response 1066 else: 1067 message = response.choices[0].message 1068 assistant_content = message.content or "" 1069 1070 if hasattr(message, "tool_calls") and message.tool_calls: 1071 for i, tool_call in enumerate(message.tool_calls): 1072 tool_calls_dict[i] = { 1073 "id": tool_call.id, 1074 "function": { 1075 "name": tool_call.function.name, 1076 "arguments": tool_call.function.arguments 1077 } 1078 } 1079 1080 if output_callback: 1081 output_callback(assistant_content) 1082 elif not quiet: 1083 print(assistant_content) 1084 1085 # Log tool calls for debugging 1086 if tool_calls_dict: 1087 self._log(f"[OPENAI TOOL CALLS] Found {len(tool_calls_dict)} tool calls", level="debug") 1088 for idx, call in tool_calls_dict.items(): 1089 self._log(f"[OPENAI TOOL CALL {idx}] {call['function']['name']} with ID {call['id']}", level="debug") 1090 1091 # Return standardized response format 1092 return { 1093 "content": assistant_content, 1094 "tool_calls": list(tool_calls_dict.values()) 1095 } 1096 1097 async def _anthropic_runner( 1098 self, 1099 *, 1100 model, 1101 messages, 1102 stream, 1103 markdown=False, 1104 quiet=False, 1105 live=None, 1106 output_callback=None 1107 ): 1108 """Handle Anthropic-specific API interactions and response processing.""" 1109 # Log what we're sending for debugging 1110 self._log(f"[ANTHROPIC REQUEST] Sending request to {model} with {len(self.history)} messages", level="debug") 1111 1112 # Prepare API parameters - history is already normalized by _normalizer 1113 params = { 1114 "model": model, 1115 "messages": self.history, 1116 "max_tokens": 8192, 1117 "system": self.system 1118 } 1119 1120 # Add tools support if needed 1121 if self.tools_enabled and self.tools_supported: 1122 enabled_tools = [] 1123 for tool in self._get_enabled_tools(): 1124 format_tool = { 1125 "name": tool["function"]["name"], 1126 "description": tool["function"].get("description", ""), 1127 "input_schema": tool["function"]["parameters"], 1128 } 1129 enabled_tools.append(format_tool) 1130 1131 params["tools"] = enabled_tools 1132 1133 assistant_content = "" 1134 tool_calls_dict = {} 1135 1136 try: 1137 # Process streaming response 1138 if stream: 1139 stream_response = await self._retry_with_backoff( 1140 self.async_client.messages.create, 1141 stream=True, 1142 **params 1143 ) 1144 1145 content_type = None 1146 async for chunk in stream_response: 1147 data = json.loads(chunk.to_json()) 1148 chunk_type = getattr(chunk, "type", "unknown") 1149 stop_reason = getattr(chunk, "stop_reason", None) 1150 self._log(f"[ANTHROPIC CHUNK] Type: {chunk_type}", level="debug") 1151 if chunk_type == "content_block_start" and hasattr(chunk.content_block, "type"): 1152 content_type = chunk.content_block.type 1153 if content_type == "tool_use": 1154 tool_id = chunk.content_block.id 1155 tool_name = chunk.content_block.name 1156 tool_input = chunk.content_block.input 1157 tool_calls_dict[tool_id] = { 1158 "id": tool_id, 1159 "function": { 1160 "name": tool_name, 1161 "arguments": "" 1162 } 1163 } 1164 self._log(f"[ANTHROPIC TOOL USE] Found complete tool use: {tool_name}", level="info") 1165 1166 # Handle text content 1167 if chunk_type == "content_block_delta" and hasattr(chunk.delta, "text"): 1168 delta = chunk.delta.text 1169 assistant_content += delta 1170 if output_callback: 1171 output_callback(delta) 1172 elif live: 1173 live.update(Markdown(assistant_content)) 1174 elif not markdown and not quiet: 1175 print(delta, end="") 1176 1177 # Handle complete tool use 1178 elif chunk_type == "content_block_delta" and content_type == "tool_use": 1179 tool_calls_dict[tool_id]["function"]["arguments"] += chunk.delta.partial_json 1180 1181 # Process non-streaming response 1182 else: 1183 response = await self._retry_with_backoff( 1184 self.async_client.messages.create, 1185 **params 1186 ) 1187 1188 # Extract text content 1189 for content_block in response.content: 1190 if content_block.type == "text": 1191 assistant_content += content_block.text 1192 1193 # Extract tool uses 1194 tool_uses = getattr(response, "tool_uses", []) 1195 if tool_uses: 1196 self._log(f"[ANTHROPIC TOOL USES] Found {len(tool_uses)} tool uses", level="debug") 1197 for i, tool_use in enumerate(tool_uses): 1198 # Extract and format the tool use 1199 tool_id = tool_use.id 1200 tool_name = tool_use.name 1201 tool_input = tool_use.input 1202 1203 # Format the input as JSON string 1204 if isinstance(tool_input, dict): 1205 input_json = json.dumps(tool_input) 1206 else: 1207 input_json = json.dumps({}) if tool_input is None else str(tool_input) 1208 1209 tool_calls_dict[tool_id] = { 1210 "id": tool_id, 1211 "function": { 1212 "name": tool_name, 1213 "arguments": input_json 1214 } 1215 } 1216 self._log(f"[ANTHROPIC TOOL USE] {tool_name}", level="debug") 1217 1218 if output_callback: 1219 output_callback(assistant_content) 1220 elif not quiet: 1221 print(assistant_content) 1222 1223 except Exception as e: 1224 self.logger.error(f"Error in Anthropic runner: {traceback.print_exc()}") 1225 # Add more detailed logging 1226 self._log(f"[ANTHROPIC ERROR] {traceback.format_exc()}", level="error") 1227 1228 # Return something usable even in case of error 1229 return { 1230 "content": f"Error processing Anthropic response: {str(e)}", 1231 "tool_calls": [] 1232 } 1233 1234 # Return standardized response format 1235 return { 1236 "content": assistant_content, 1237 "tool_calls": list(tool_calls_dict.values()) 1238 } 1239 1240 1241 1242 def _get_enabled_tools(self) -> List[dict]: 1243 """Return the list of currently enabled tool function definitions.""" 1244 return [ 1245 tool for tool in self.tools 1246 if not tool["function"].get("disabled", False) 1247 ] 1248 1249 async def _handle_tool_call_async( 1250 self, 1251 function_name: str, 1252 function_arguments: str, 1253 tool_call_id: str, 1254 params: dict, 1255 markdown: bool, 1256 live: Optional[Live], 1257 safe: bool = False, 1258 output_callback: Optional[Callable[[str], None]] = None 1259 ) -> str: 1260 """Process a tool call asynchronously and return the result. 1261 1262 Args: 1263 function_name: Name of the function to call. 1264 function_arguments: JSON string containing the function arguments. 1265 tool_call_id: Unique identifier for this tool call. 1266 params: Parameters used for the original API call. 1267 markdown: If True, renders content as markdown. 1268 live: Optional Live context for updating content in real-time. 1269 safe: If True, prompts for confirmation before executing the tool call. 1270 output_callback: Optional callback to handle the tool call result. 1271 1272 Returns: 1273 The result of the function call. 1274 1275 Raises: 1276 ValueError: If the function is not found or JSON is invalid. 1277 """ 1278 live_was_active = False 1279 1280 try: 1281 arguments = json.loads(function_arguments or "{}") 1282 except json.JSONDecodeError as e: 1283 self.logger.error(f"Invalid JSON in tool call arguments: {e}") 1284 return {"error": "Invalid JSON format in tool call arguments."} 1285 1286 self._log(f"[TOOL:{function_name}] args={arguments}") 1287 1288 func = getattr(self, function_name, None) 1289 if not func: 1290 raise ValueError(f"Function '{function_name}' not found.") 1291 1292 # Properly handle live display if it's active 1293 if live and live.is_started: 1294 live_was_active = True 1295 live.stop() 1296 1297 print(f"\nRunning {function_name}...\n") 1298 1299 if output_callback: 1300 notification = json.dumps({ 1301 "type": "tool_call", 1302 "tool_name": function_name, 1303 "status": "started" 1304 }) 1305 output_callback(notification) 1306 1307 try: 1308 if safe: 1309 prompt = f"[bold yellow]Proposed tool call:[/bold yellow] {function_name}({json.dumps(arguments, indent=2)})\n[bold cyan]Execute? [y/n]: [/bold cyan]" 1310 confirmed = Confirm.ask(prompt, default=False) 1311 if not confirmed: 1312 command_result = { 1313 "status": "cancelled", 1314 "message": "Tool call aborted by user" 1315 } 1316 print("[red]Tool call cancelled by user[/red]") 1317 else: 1318 loop = asyncio.get_event_loop() 1319 command_result = await loop.run_in_executor(None, lambda: func(**arguments)) 1320 else: 1321 loop = asyncio.get_event_loop() 1322 command_result = await loop.run_in_executor(None, lambda: func(**arguments)) 1323 1324 try: 1325 json.dumps(command_result) 1326 except TypeError as e: 1327 self.logger.error(f"Tool call result not serializable: {e}") 1328 return {"error": "Tool call returned unserializable data."} 1329 1330 if output_callback: 1331 notification = json.dumps({ 1332 "type": "tool_call", 1333 "tool_name": function_name, 1334 "status": "completed" 1335 }) 1336 output_callback(notification) 1337 1338 # Restart live display if it was active before 1339 if live_was_active and live: 1340 live.start() 1341 1342 return command_result 1343 1344 except Exception as e: 1345 self._log(f"[ERROR] Tool execution failed: {e}", level="error") 1346 self.logger.error(f"Error executing tool function '{function_name}': {e}") 1347 return {"error": str(e)} 1348 1349 def _setup_encoding(self): 1350 """Set up the token encoding based on the current model.""" 1351 try: 1352 if self.provider == "openai": 1353 try: 1354 self.encoding = tiktoken.encoding_for_model(self.model) 1355 self._log(f"[ENCODING] Loaded tokenizer for OpenAI model: {self.model}") 1356 except: 1357 self.encoding = tiktoken.get_encoding("cl100k_base") 1358 self._log(f"[ENCODING] Fallback to cl100k_base for model: {self.model}") 1359 else: 1360 self.encoding = tiktoken.get_encoding("cl100k_base") 1361 self._log(f"[ENCODING] Defaulting to cl100k_base for non-OpenAI model: {self.model}") 1362 except Exception as e: 1363 self.logger.error(f"Failed to setup encoding: {e}") 1364 self.encoding = tiktoken.get_encoding("cl100k_base") 1365 1366 def _estimate_tokens_tiktoken(self, messages) -> int: 1367 """Rough token count estimate using tiktoken for OpenAI or fallback cases.""" 1368 if not hasattr(self, "encoding") or not self.encoding: 1369 self._setup_encoding() 1370 return sum(len(self.encoding.encode(msg.get("content", ""))) for msg in messages if isinstance(msg.get("content"), str)) 1371 1372 def _count_tokens(self, messages, use_cache=True) -> int: 1373 """Accurately estimate token count for messages including tool calls with caching support. 1374 1375 Args: 1376 messages: List of message objects in either OpenAI or Anthropic format. 1377 use_cache: Whether to use and update the token count cache. 1378 1379 Returns: 1380 int: Estimated token count. 1381 """ 1382 # Setup encoding if needed 1383 if not hasattr(self, "encoding") or not self.encoding: 1384 self._setup_encoding() 1385 1386 # Initialize cache if it doesn't exist 1387 if not hasattr(self, "_token_count_cache"): 1388 self._token_count_cache = {} 1389 1390 # Generate a cache key based on message content hashes 1391 if use_cache: 1392 try: 1393 # Create a cache key using message IDs if available, or content hashes 1394 cache_key_parts = [] 1395 for msg in messages: 1396 if isinstance(msg, dict): 1397 # Try to use stable identifiers for cache key 1398 msg_id = msg.get("id", None) 1399 timestamp = msg.get("timestamp", None) 1400 1401 if msg_id and timestamp: 1402 cache_key_parts.append(f"{msg_id}:{timestamp}") 1403 else: 1404 # Fall back to content-based hash if no stable IDs 1405 content_str = str(msg.get("content", "")) 1406 role = msg.get("role", "unknown") 1407 cache_key_parts.append(f"{role}:{hash(content_str)}") 1408 1409 cache_key = ":".join(cache_key_parts) 1410 if cache_key in self._token_count_cache: 1411 return self._token_count_cache[cache_key] 1412 except Exception as e: 1413 # If caching fails, just continue with normal counting 1414 self._log(f"[TOKEN COUNT] Cache key generation failed: {e}", level="debug") 1415 use_cache = False 1416 1417 # For Claude models, try to use their built-in token counter 1418 if self.sdk == "anthropic": 1419 try: 1420 # Convert messages to Anthropic format if needed 1421 anthropic_messages = [] 1422 for msg in messages: 1423 if msg.get("role") == "system": 1424 continue # System handled separately 1425 1426 if msg.get("role") == "tool": 1427 # Skip tool messages in token count to avoid double-counting 1428 continue 1429 1430 if msg.get("role") == "user" and isinstance(msg.get("content"), list): 1431 # Already in Anthropic format with tool_result 1432 anthropic_messages.append(msg) 1433 elif msg.get("role") in ["user", "assistant"]: 1434 if not msg.get("tool_calls") and not msg.get("tool_use"): 1435 # Simple message 1436 anthropic_messages.append({ 1437 "role": msg.get("role"), 1438 "content": msg.get("content", "") 1439 }) 1440 1441 # Use Anthropic's token counter if messages exist 1442 if anthropic_messages: 1443 response = self.client.messages.count_tokens( 1444 model=self.model, 1445 messages=anthropic_messages, 1446 system=self.system 1447 ) 1448 token_count = response.input_tokens 1449 1450 # Cache the result for future use 1451 if use_cache and 'cache_key' in locals(): 1452 self._token_count_cache[cache_key] = token_count 1453 1454 return token_count 1455 except Exception as e: 1456 # Fall back to our estimation 1457 self._log(f"[TOKEN COUNT] Error using Anthropic token counter: {e}", level="debug") 1458 1459 # More accurate token counting for all message types 1460 num_tokens = 0 1461 1462 # Count tokens for each message 1463 for msg in messages: 1464 # Base token count for message metadata (role + message format) 1465 num_tokens += 4 # Message overhead 1466 1467 # Add tokens for role name 1468 role = msg.get("role", "") 1469 num_tokens += len(self.encoding.encode(role)) 1470 1471 # Count tokens in message content 1472 if isinstance(msg.get("content"), str): 1473 content = msg.get("content", "") 1474 content_tokens = len(self.encoding.encode(content)) 1475 num_tokens += content_tokens 1476 1477 elif isinstance(msg.get("content"), list): 1478 # Handle Anthropic-style content lists 1479 for item in msg.get("content", []): 1480 if isinstance(item, dict): 1481 # Tool result or other structured content 1482 if item.get("type") == "tool_result": 1483 result_content = item.get("content", "") 1484 if isinstance(result_content, str): 1485 num_tokens += len(self.encoding.encode(result_content)) 1486 else: 1487 # JSON serialization for dict/list content 1488 num_tokens += len(self.encoding.encode(json.dumps(result_content))) 1489 # Add tokens for tool_use_id and type fields 1490 num_tokens += len(self.encoding.encode(item.get("type", ""))) 1491 num_tokens += len(self.encoding.encode(item.get("tool_use_id", ""))) 1492 1493 # Text content type 1494 elif item.get("type") == "text": 1495 num_tokens += len(self.encoding.encode(item.get("text", ""))) 1496 1497 # Tool use type 1498 elif item.get("type") == "tool_use": 1499 num_tokens += len(self.encoding.encode(item.get("name", ""))) 1500 tool_input = item.get("input", {}) 1501 if isinstance(tool_input, str): 1502 num_tokens += len(self.encoding.encode(tool_input)) 1503 else: 1504 num_tokens += len(self.encoding.encode(json.dumps(tool_input))) 1505 num_tokens += len(self.encoding.encode(item.get("id", ""))) 1506 else: 1507 # Plain text content 1508 num_tokens += len(self.encoding.encode(str(item))) 1509 1510 # Count tokens in tool calls for OpenAI format 1511 if msg.get("tool_calls"): 1512 for tool_call in msg.get("tool_calls", []): 1513 if isinstance(tool_call, dict): 1514 # Count tokens for function name 1515 func_name = tool_call.get("function", {}).get("name", "") 1516 num_tokens += len(self.encoding.encode(func_name)) 1517 1518 # Count tokens for arguments 1519 args = tool_call.get("function", {}).get("arguments", "") 1520 if isinstance(args, str): 1521 num_tokens += len(self.encoding.encode(args)) 1522 else: 1523 num_tokens += len(self.encoding.encode(json.dumps(args))) 1524 1525 # Add tokens for id and type fields 1526 num_tokens += len(self.encoding.encode(tool_call.get("id", ""))) 1527 num_tokens += len(self.encoding.encode(tool_call.get("type", "function"))) 1528 1529 # Count tokens in Anthropic tool_use field 1530 if msg.get("tool_use"): 1531 tool_use = msg.get("tool_use") 1532 # Count tokens for name 1533 num_tokens += len(self.encoding.encode(tool_use.get("name", ""))) 1534 1535 # Count tokens for input 1536 tool_input = tool_use.get("input", {}) 1537 if isinstance(tool_input, str): 1538 num_tokens += len(self.encoding.encode(tool_input)) 1539 else: 1540 num_tokens += len(self.encoding.encode(json.dumps(tool_input))) 1541 1542 # Add tokens for id field 1543 num_tokens += len(self.encoding.encode(tool_use.get("id", ""))) 1544 1545 # Handle tool response message format 1546 if msg.get("role") == "tool": 1547 # Add tokens for tool_call_id 1548 tool_id = msg.get("tool_call_id", "") 1549 num_tokens += len(self.encoding.encode(tool_id)) 1550 1551 # Add message end tokens 1552 num_tokens += 2 1553 1554 # Cache the result for future use 1555 if use_cache and 'cache_key' in locals(): 1556 self._token_count_cache[cache_key] = num_tokens 1557 1558 return num_tokens 1559 1560 def _cycle_messages(self): 1561 """ 1562 Intelligently trim the message history to fit within the allowed context length. 1563 1564 This method implements a sophisticated trimming strategy: 1565 1. Always preserves system messages 1566 2. Always keeps the most recent complete conversation turn 1567 3. Prioritizes keeping tool call chains intact 1568 4. Preserves important context from earlier exchanges 1569 5. Aggressively prunes redundant information before essential content 1570 1571 Returns: 1572 bool: True if all messages were trimmed (context exceeded), False otherwise. 1573 """ 1574 # Check if we need to trim 1575 token_count = self._count_tokens(self.history) 1576 1577 # If we're already under the limit, return early 1578 if token_count <= self.context_length: 1579 return False 1580 1581 self._log(f"[TRIM] Starting message cycling: {token_count} tokens exceeds {self.context_length} limit", level="info") 1582 1583 # We'll need to track tokens as we reconstruct the history 1584 remaining_tokens = token_count 1585 target_tokens = max(self.context_length * 0.8, self.context_length - 1000) # Target 80% or 1000 less than max 1586 1587 # First pass: identify critical messages we must keep 1588 must_keep = [] 1589 tool_chain_groups = {} # Group related tool calls and their results 1590 1591 # Always keep system messages (should be first) 1592 system_indices = [] 1593 for i, msg in enumerate(self.history): 1594 if msg.get("role") == "system": 1595 system_indices.append(i) 1596 must_keep.append(i) 1597 1598 # Identify the most recent complete exchange (user question + assistant response) 1599 latest_exchange = [] 1600 # Start from the end and work backward to find the last complete exchange 1601 for i in range(len(self.history) - 1, -1, -1): 1602 msg = self.history[i] 1603 if msg.get("role") == "assistant" and not latest_exchange: 1604 latest_exchange.append(i) 1605 elif msg.get("role") == "user" and latest_exchange: 1606 latest_exchange.append(i) 1607 break 1608 1609 # Add the latest exchange to must-keep 1610 must_keep.extend(latest_exchange) 1611 1612 # Identify tool chains - track which messages belong to the same tool flow 1613 tool_id_to_chain = {} 1614 for i, msg in enumerate(self.history): 1615 # For assistant messages with tool calls 1616 if msg.get("role") == "assistant" and msg.get("tool_calls"): 1617 for tool_call in msg.get("tool_calls"): 1618 tool_id = tool_call.get("id") 1619 if tool_id: 1620 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 1621 1622 # For tool response messages 1623 elif msg.get("role") == "tool" and msg.get("tool_call_id"): 1624 tool_id = msg.get("tool_call_id") 1625 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 1626 1627 # For Anthropic format with tool use 1628 elif msg.get("role") == "assistant" and isinstance(msg.get("content"), list): 1629 for block in msg.get("content", []): 1630 if isinstance(block, dict) and block.get("type") == "tool_use": 1631 tool_id = block.get("id") 1632 if tool_id: 1633 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 1634 1635 # For Anthropic tool result messages 1636 elif msg.get("role") == "user" and isinstance(msg.get("content"), list): 1637 for block in msg.get("content", []): 1638 if isinstance(block, dict) and block.get("type") == "tool_result": 1639 tool_id = block.get("tool_use_id") 1640 if tool_id: 1641 tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i] 1642 1643 # Group together all indices for each tool chain 1644 for tool_id, indices in tool_id_to_chain.items(): 1645 chain_key = f"tool_{min(indices)}" # Group by the earliest message 1646 if chain_key not in tool_chain_groups: 1647 tool_chain_groups[chain_key] = set() 1648 tool_chain_groups[chain_key].update(indices) 1649 1650 # Second pass: calculate tokens for each message 1651 message_tokens = [] 1652 for i, msg in enumerate(self.history): 1653 # Count tokens for this individual message 1654 tokens = self._count_tokens([msg]) 1655 message_tokens.append((i, tokens)) 1656 1657 # Keep the messages identified as must-keep 1658 keep_indices = set(must_keep) 1659 1660 # Calculate the tokens we've committed to keeping 1661 keep_tokens = sum(tokens for i, tokens in message_tokens if i in keep_indices) 1662 1663 # Check if we've already exceeded the target with just must-keep messages 1664 if keep_tokens > self.context_length: 1665 # We're in trouble - the essential messages alone exceed context 1666 # Drop older messages until we're under the limit 1667 all_indices = sorted(keep_indices) 1668 1669 # Start dropping oldest messages, but NEVER drop system messages 1670 for idx in all_indices: 1671 if idx not in system_indices: 1672 keep_indices.remove(idx) 1673 keep_tokens -= message_tokens[idx][1] 1674 if keep_tokens <= target_tokens: 1675 break 1676 1677 # If we've removed everything but system messages and still over limit 1678 if keep_tokens > self.context_length: 1679 self._log(f"[TRIM] Critical failure: even with minimal context ({keep_tokens} tokens), we exceed the limit", level="error") 1680 # Keep only system messages if any 1681 keep_indices = set(system_indices) 1682 return True # Context exceeded completely 1683 1684 # Third pass: keep the most important tool chains intact 1685 available_tokens = target_tokens - keep_tokens 1686 # Sort tool chains by recency (assumed by the chain_key which uses the earliest message) 1687 sorted_chains = sorted(tool_chain_groups.items(), key=lambda x: x[0], reverse=True) 1688 1689 for chain_key, indices in sorted_chains: 1690 # Skip if we've already decided to keep all messages in this chain 1691 if indices.issubset(keep_indices): 1692 continue 1693 1694 # Calculate how many tokens this chain would add 1695 chain_tokens = sum(tokens for i, tokens in message_tokens if i in indices and i not in keep_indices) 1696 1697 # If we can fit the entire chain, keep it 1698 if chain_tokens <= available_tokens: 1699 keep_indices.update(indices) 1700 available_tokens -= chain_tokens 1701 # Otherwise, we might want to keep partial chains in the future, but for now, skip 1702 1703 # Fourth pass: fill in with as many remaining messages as possible, prioritizing recency 1704 # Get remaining messages sorted by recency (newest first) 1705 remaining_indices = [(i, tokens) for i, tokens in message_tokens if i not in keep_indices] 1706 remaining_indices.sort(reverse=True) # Sort newest first 1707 1708 for i, tokens in remaining_indices: 1709 if tokens <= available_tokens: 1710 keep_indices.add(i) 1711 available_tokens -= tokens 1712 1713 # Final message reconstruction 1714 self._log(f"[TRIM] Keeping {len(keep_indices)}/{len(self.history)} messages, estimated {target_tokens - available_tokens} tokens", level="info") 1715 1716 # Create new history with just the kept messages, preserving order 1717 new_history = [self.history[i] for i in sorted(keep_indices)] 1718 self.history = new_history 1719 1720 # Update session_history to match the pruned history 1721 if hasattr(self, "session_history"): 1722 # Map between history items and session_history 1723 session_to_keep = [] 1724 1725 # For each session history message, check if it corresponds to a kept message 1726 for session_msg in self.session_history: 1727 # Keep system messages 1728 if session_msg.get("role") == "system": 1729 session_to_keep.append(session_msg) 1730 continue 1731 1732 # Try to match based on available IDs or content 1733 msg_id = session_msg.get("id") 1734 1735 # For tool messages, check tool_info.id against tool_call_id 1736 if "metadata" in session_msg and "tool_info" in session_msg["metadata"]: 1737 tool_id = session_msg["metadata"]["tool_info"].get("id") 1738 1739 # Check if this tool_id is still in the kept history 1740 for history_msg in new_history: 1741 # Check standard ids 1742 history_tool_id = None 1743 1744 # Check OpenAI format 1745 if history_msg.get("role") == "tool": 1746 history_tool_id = history_msg.get("tool_call_id") 1747 elif history_msg.get("role") == "assistant" and history_msg.get("tool_calls"): 1748 for call in history_msg.get("tool_calls", []): 1749 if call.get("id") == tool_id: 1750 history_tool_id = call.get("id") 1751 break 1752 1753 # Check Anthropic format 1754 elif isinstance(history_msg.get("content"), list): 1755 for block in history_msg.get("content", []): 1756 if isinstance(block, dict): 1757 if block.get("type") == "tool_use" and block.get("id") == tool_id: 1758 history_tool_id = block.get("id") 1759 break 1760 elif block.get("type") == "tool_result" and block.get("tool_use_id") == tool_id: 1761 history_tool_id = block.get("tool_use_id") 1762 break 1763 1764 if history_tool_id == tool_id: 1765 session_to_keep.append(session_msg) 1766 break 1767 1768 # For regular messages, try content matching as fallback 1769 else: 1770 content_match = False 1771 if isinstance(session_msg.get("content"), str) and session_msg.get("content"): 1772 for history_msg in new_history: 1773 if history_msg.get("role") == session_msg.get("role") and history_msg.get("content") == session_msg.get("content"): 1774 content_match = True 1775 break 1776 1777 if content_match: 1778 session_to_keep.append(session_msg) 1779 1780 # Update session_history with kept messages 1781 self.session_history = session_to_keep 1782 1783 # Re-normalize to ensure consistency 1784 self._normalizer(force=True) 1785 1786 # Verify our final token count 1787 final_token_count = self._count_tokens(self.history) 1788 self._log(f"[TRIM] Final history has {len(self.history)} messages, {final_token_count} tokens", level="info") 1789 1790 # Return whether we've completely exceeded context 1791 return final_token_count > self.context_length or len(self.history) == len(system_indices) 1792 1793 1794 def messages_add( 1795 self, 1796 role: str, 1797 content: Any, 1798 tool_info: Optional[Dict] = None, 1799 normalize: bool = True 1800 ) -> str: 1801 """ 1802 Add a message to the standardized session_history and then update SDK-specific history. 1803 1804 This method is the central point for all message additions to the conversation. 1805 1806 Args: 1807 role: The role of the message ("user", "assistant", "system", "tool") 1808 content: The message content (text or structured) 1809 tool_info: Optional tool-related metadata 1810 normalize: Whether to normalize history after adding this message 1811 1812 Returns: 1813 str: Unique ID of the added message 1814 """ 1815 # Generate a unique message ID 1816 message_id = str(uuid.uuid4()) 1817 1818 # Create the standardized message for session_history 1819 timestamp = datetime.now(timezone.utc).isoformat() 1820 1821 # Store system messages directly 1822 if role == "system": 1823 self.system = content 1824 1825 # Create standard format message 1826 standard_message = { 1827 "role": role, 1828 "content": content, 1829 "id": message_id, 1830 "timestamp": timestamp, 1831 "metadata": { 1832 "sdk": self.sdk 1833 } 1834 } 1835 1836 # Add tool info if provided 1837 if tool_info: 1838 standard_message["metadata"]["tool_info"] = tool_info 1839 1840 # Add to session_history 1841 if not hasattr(self, "session_history"): 1842 self.session_history = [] 1843 1844 self.session_history.append(standard_message) 1845 1846 # Save to persistent session if enabled 1847 if self.session_enabled and self.session_id: 1848 # Convert standard message to session-compatible format 1849 session_msg = { 1850 "role": role, 1851 "content": content, 1852 "id": message_id, 1853 "timestamp": timestamp 1854 } 1855 1856 # Add tool-related fields if present 1857 if tool_info: 1858 for key, value in tool_info.items(): 1859 session_msg[key] = value 1860 1861 # Store in session 1862 self.session.msg_insert(self.session_id, session_msg) 1863 1864 # Update the SDK-specific format in self.history by running the normalizer 1865 if normalize: 1866 # We only need to normalize the most recent message for efficiency 1867 # Pass a flag indicating we're just normalizing a new message 1868 self._normalizer(force=False, new_message_only=True) 1869 1870 # Log the added message 1871 self._log(f"[MESSAGE ADDED] {role}: {str(content)[:50]}...") 1872 1873 return message_id 1874 1875 def messages_system(self, prompt: str): 1876 """Set or retrieve the current system prompt.""" 1877 if not isinstance(prompt, str) or not prompt: 1878 return self.system 1879 1880 # If the prompt hasn't changed, don't do anything 1881 if self.system == prompt: 1882 return self.system 1883 1884 # Update the system prompt 1885 old_system = self.system 1886 self.system = prompt 1887 1888 # For OpenAI, update or insert the system message in history 1889 if self.sdk == "openai": 1890 # Check if there's already a system message 1891 system_index = next((i for i, msg in enumerate(self.history) 1892 if msg.get("role") == "system"), None) 1893 1894 if system_index is not None: 1895 # Update existing system message 1896 self.history[system_index]["content"] = prompt 1897 else: 1898 # Insert new system message at the beginning 1899 self.history.insert(0, {"role": "system", "content": prompt}) 1900 1901 # For Anthropic, system message is not part of history, just save it for API calls 1902 1903 # Log to session only if prompt actually changed 1904 if self.session_enabled and self.session_id and old_system != prompt: 1905 self.session.msg_insert(self.session_id, {"role": "system", "content": prompt}) 1906 1907 return self.system 1908 1909 def messages(self) -> list: 1910 """Return full session messages (persisted or in-memory).""" 1911 if self.session_enabled and self.session_id: 1912 return self.session.load_full(self.session_id).get("messages", []) 1913 return self.session_history 1914 1915 def messages_length(self) -> int: 1916 """Calculate the total token count for the message history.""" 1917 if not self.encoding: 1918 return 0 1919 1920 total_tokens = 0 1921 for message in self.history: 1922 if message.get("content"): 1923 total_tokens += len(self.encoding.encode(message["content"])) 1924 if message.get("tool_calls"): 1925 for tool_call in message["tool_calls"]: 1926 if tool_call.get("function"): 1927 total_tokens += len(self.encoding.encode(tool_call["function"].get("name", ""))) 1928 total_tokens += len(self.encoding.encode(tool_call["function"].get("arguments", ""))) 1929 return total_tokens 1930 1931 def session_load(self, session_id: Optional[str]): 1932 """Load and normalize messages for a specific session.""" 1933 self.session_id = session_id 1934 self._last_session_id = session_id 1935 1936 if self.session_enabled and session_id: 1937 try: 1938 # Load raw session data 1939 session_data = self.session.load_full(session_id) 1940 messages = session_data.get("messages", []) 1941 1942 # Convert session format to our standard format 1943 self.session_history = [] 1944 1945 # Track the most recent system message 1946 latest_system_msg = None 1947 1948 for msg in messages: 1949 # Extract fields 1950 role = msg.get("role", "user") 1951 content = msg.get("content", "") 1952 msg_id = msg.get("id", str(uuid.uuid4())) 1953 timestamp = msg.get("timestamp", datetime.now(timezone.utc).isoformat()) 1954 1955 # If this is a system message, track it but don't add to session_history yet 1956 if role == "system": 1957 if latest_system_msg is None or timestamp > latest_system_msg["timestamp"]: 1958 latest_system_msg = { 1959 "role": role, 1960 "content": content, 1961 "id": msg_id, 1962 "timestamp": timestamp, 1963 "metadata": {"sdk": self.sdk} 1964 } 1965 continue 1966 1967 # Build tool_info if present 1968 tool_info = None 1969 if any(key in msg for key in ["tool_use_id", "tool_call_id", "name", "arguments"]): 1970 tool_info = { 1971 "id": msg.get("tool_use_id") or msg.get("tool_call_id"), 1972 "name": msg.get("name", "unknown_tool"), 1973 "arguments": msg.get("arguments", {}) 1974 } 1975 1976 # Create standard message 1977 standard_msg = { 1978 "role": role, 1979 "content": content, 1980 "id": msg_id, 1981 "timestamp": timestamp, 1982 "metadata": { 1983 "sdk": self.sdk 1984 } 1985 } 1986 1987 if tool_info: 1988 standard_msg["metadata"]["tool_info"] = tool_info 1989 1990 self.session_history.append(standard_msg) 1991 1992 # If we found a system message, update the system property and add to history 1993 if latest_system_msg: 1994 self.system = latest_system_msg["content"] 1995 # Insert at the beginning of session_history 1996 self.session_history.insert(0, latest_system_msg) 1997 else: 1998 # If no system message was found, add the current system message 1999 self.messages_add(role="system", content=self.system) 2000 2001 # Normalize to current SDK format 2002 self._normalizer(force=True) 2003 2004 self._log(f"[SESSION] Switched to session '{session_id}'") 2005 except Exception as e: 2006 self.logger.error(f"Failed to load session '{session_id}': {e}") 2007 self.session_reset() 2008 else: 2009 # Reset to empty state with system message 2010 self.session_reset() 2011 2012 def session_reset(self): 2013 """ 2014 Reset the current session state and reinitialize to default system prompt. 2015 2016 Clears history, disables session ID tracking, and returns to in-memory mode. 2017 """ 2018 self.session_id = None 2019 self._last_session_id = None 2020 2021 # Clear histories 2022 self.session_history = [] 2023 self.history = [] 2024 2025 # Reapply the system message 2026 if hasattr(self, "system") and self.system: 2027 # Add to session_history 2028 self.messages_add(role="system", content=self.system) 2029 else: 2030 # Ensure we have a default system message 2031 self.system = "You are a helpful Assistant." 2032 self.messages_add(role="system", content=self.system) 2033 2034 self._log("[SESSION] Reset to in-memory mode") 2035 2036 def _normalizer(self, force=False, new_message_only=False): 2037 """ 2038 Central normalization function that transforms the standard session_history 2039 into the SDK-specific format needed in self.history. 2040 2041 Args: 2042 force (bool): If True, always normalize even if SDK hasn't changed. 2043 Default is False, which only normalizes on SDK change. 2044 new_message_only (bool): If True, only normalize the most recent message 2045 for efficiency when adding single messages. 2046 """ 2047 # Skip normalization if SDK hasn't changed and force is False 2048 if not force and hasattr(self, '_last_sdk') and self._last_sdk == self.sdk: 2049 # If we only need to normalize the most recent message 2050 if new_message_only and self.session_history: 2051 # Get the most recent message from session_history 2052 recent_msg = self.session_history[-1] 2053 2054 # Apply SDK-specific normalization for just this message 2055 if self.sdk == "openai": 2056 self._openai_normalize_message(recent_msg) 2057 elif self.sdk == "anthropic": 2058 self._anthropic_normalize_message(recent_msg) 2059 else: 2060 # Generic handler for unknown SDKs 2061 self._generic_normalize_message(recent_msg) 2062 2063 return 2064 2065 # Record the current SDK to detect future changes 2066 self._last_sdk = self.sdk 2067 2068 # For full normalization, clear current history and rebuild it 2069 self.history = [] 2070 2071 # Call the appropriate SDK-specific normalizer 2072 if self.sdk == "openai": 2073 self._openai_normalizer() 2074 elif self.sdk == "anthropic": 2075 self._anthropic_normalizer() 2076 else: 2077 self.logger.warning(f"No normalizer available for SDK: {self.sdk}") 2078 # Fallback to a simple conversion for unknown SDKs 2079 for msg in self.session_history: 2080 self._generic_normalize_message(msg) 2081 2082 def _openai_normalizer(self): 2083 """ 2084 Convert standardized session_history to OpenAI-compatible format in self.history. 2085 """ 2086 # For OpenAI, we need to include system message in the history 2087 # and convert tool calls/results to OpenAI format 2088 2089 # Start with empty history 2090 self.history = [] 2091 2092 # First, add the current system message at position 0 2093 self.history.append({ 2094 "role": "system", 2095 "content": self.system 2096 }) 2097 2098 # Process all non-system messages 2099 for msg in self.session_history: 2100 if msg["role"] == "system": 2101 continue # Skip system messages, already handled 2102 2103 # Handle different message types 2104 if msg["role"] == "user": 2105 # User messages are straightforward 2106 self.history.append({ 2107 "role": "user", 2108 "content": msg["content"] 2109 }) 2110 2111 elif msg["role"] == "assistant": 2112 # For assistant messages with tool calls 2113 if "metadata" in msg and msg["metadata"].get("tool_info"): 2114 # This is an assistant message with tool calls 2115 tool_info = msg["metadata"]["tool_info"] 2116 2117 # Create OpenAI assistant message with tool calls 2118 assistant_msg = { 2119 "role": "assistant", 2120 "content": msg["content"] if isinstance(msg["content"], str) else "", 2121 "tool_calls": [{ 2122 "id": tool_info["id"], 2123 "type": "function", 2124 "function": { 2125 "name": tool_info["name"], 2126 "arguments": json.dumps(tool_info["arguments"]) if isinstance(tool_info["arguments"], dict) else tool_info["arguments"] 2127 } 2128 }] 2129 } 2130 self.history.append(assistant_msg) 2131 else: 2132 # Regular assistant message 2133 self.history.append({ 2134 "role": "assistant", 2135 "content": msg["content"] 2136 }) 2137 2138 elif msg["role"] == "tool": 2139 # Tool response messages 2140 if "metadata" in msg and "tool_info" in msg["metadata"]: 2141 tool_msg = { 2142 "role": "tool", 2143 "tool_call_id": msg["metadata"]["tool_info"]["id"], 2144 "content": json.dumps(msg["content"]) if isinstance(msg["content"], (dict, list)) else msg["content"] 2145 } 2146 self.history.append(tool_msg) 2147 2148 def _anthropic_normalizer(self): 2149 """ 2150 Convert standardized session_history to Anthropic-compatible format in self.history. 2151 """ 2152 # For Anthropic, we don't include system message in the history 2153 # but need to handle content blocks for tool use/results 2154 2155 # Start with empty history 2156 self.history = [] 2157 2158 # Process all non-system messages 2159 for msg in self.session_history: 2160 if msg["role"] == "system": 2161 # Update system prompt if this is the most recent system message 2162 # (only apply the most recent system message if we have multiple) 2163 if msg == self.session_history[-1] or all(m["role"] != "system" for m in self.session_history[self.session_history.index(msg)+1:]): 2164 self.system = msg["content"] 2165 continue # Skip system messages in history 2166 2167 # Handle different message types 2168 if msg["role"] == "user": 2169 # User messages - check if it contains tool results 2170 if "metadata" in msg and "tool_info" in msg["metadata"] and msg["metadata"]["tool_info"].get("result"): 2171 # This is a tool result message 2172 tool_info = msg["metadata"]["tool_info"] 2173 2174 # Create Anthropic tool result format 2175 tool_result_msg = { 2176 "role": "user", 2177 "content": [{ 2178 "type": "tool_result", 2179 "tool_use_id": tool_info["id"], 2180 "content": json.dumps(tool_info["result"]) if isinstance(tool_info["result"], (dict, list)) else str(tool_info["result"]) 2181 }] 2182 } 2183 self.history.append(tool_result_msg) 2184 else: 2185 # Regular user message 2186 self.history.append({ 2187 "role": "user", 2188 "content": msg["content"] 2189 }) 2190 2191 elif msg["role"] == "assistant": 2192 # For assistant messages, check for tool use 2193 if "metadata" in msg and "tool_info" in msg["metadata"]: 2194 # This is an assistant message with tool use 2195 tool_info = msg["metadata"]["tool_info"] 2196 current_tool_use_id = tool_info["id"] 2197 2198 # Build content blocks 2199 content_blocks = [] 2200 2201 # Add text content if present 2202 if msg["content"]: 2203 content_blocks.append({ 2204 "type": "text", 2205 "text": msg["content"] if isinstance(msg["content"], str) else "" 2206 }) 2207 2208 # Add tool use block 2209 content_blocks.append({ 2210 "type": "tool_use", 2211 "id": tool_info["id"], 2212 "name": tool_info["name"], 2213 "input": tool_info["arguments"] if isinstance(tool_info["arguments"], dict) else json.loads(tool_info["arguments"]) 2214 }) 2215 2216 # Create Anthropic assistant message with tool use 2217 self.history.append({ 2218 "role": "assistant", 2219 "content": content_blocks 2220 }) 2221 else: 2222 # Regular assistant message 2223 self.history.append({ 2224 "role": "assistant", 2225 "content": msg["content"] 2226 }) 2227 2228 elif msg["role"] == "tool": 2229 # Tool messages in standard format get converted to user messages with tool_result 2230 if "metadata" in msg and "tool_info" in msg["metadata"]: 2231 tool_info = msg["metadata"]["tool_info"] 2232 2233 # Create Anthropic tool result message 2234 tool_result_msg = { 2235 "role": "user", 2236 "content": [{ 2237 "type": "tool_result", 2238 "tool_use_id": tool_info["id"], 2239 "content": json.dumps(msg["content"]) if isinstance(msg["content"], (dict, list)) else str(msg["content"]) 2240 }] 2241 } 2242 self.history.append(tool_result_msg) 2243 2244 def _openai_normalize_message(self, msg): 2245 """Normalize a single message to OpenAI format and add to history.""" 2246 role = msg.get("role") 2247 content = msg.get("content") 2248 2249 if role == "system": 2250 # Check if we already have a system message in history 2251 system_index = next((i for i, m in enumerate(self.history) 2252 if m.get("role") == "system"), None) 2253 if system_index is not None: 2254 # Update existing system message 2255 self.history[system_index]["content"] = content 2256 else: 2257 # Insert new system message at the beginning 2258 self.history.insert(0, { 2259 "role": "system", 2260 "content": content 2261 }) 2262 # Update the system property 2263 self.system = content 2264 2265 elif role == "user": 2266 self.history.append({ 2267 "role": "user", 2268 "content": content 2269 }) 2270 2271 elif role == "assistant": 2272 # For assistant messages, handle potential tool calls 2273 if "metadata" in msg and msg["metadata"].get("tool_info"): 2274 # This is an assistant message with tool calls 2275 tool_info = msg["metadata"]["tool_info"] 2276 2277 # Create OpenAI assistant message with tool calls 2278 try: 2279 arguments = tool_info.get("arguments", {}) 2280 arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else arguments 2281 except: 2282 arguments_str = str(arguments) 2283 2284 assistant_msg = { 2285 "role": "assistant", 2286 "content": content if isinstance(content, str) else "", 2287 "tool_calls": [{ 2288 "id": tool_info["id"], 2289 "type": "function", 2290 "function": { 2291 "name": tool_info["name"], 2292 "arguments": arguments_str 2293 } 2294 }] 2295 } 2296 self.history.append(assistant_msg) 2297 else: 2298 # Regular assistant message 2299 self.history.append({ 2300 "role": "assistant", 2301 "content": content 2302 }) 2303 2304 elif role == "tool": 2305 # Tool response messages 2306 if "metadata" in msg and "tool_info" in msg["metadata"]: 2307 tool_info = msg["metadata"]["tool_info"] 2308 tool_msg = { 2309 "role": "tool", 2310 "tool_call_id": tool_info["id"], 2311 "content": json.dumps(content) if isinstance(content, (dict, list)) else str(content) 2312 } 2313 self.history.append(tool_msg) 2314 2315 def _anthropic_normalize_message(self, msg): 2316 """Normalize a single message to Anthropic format and add to history.""" 2317 role = msg.get("role") 2318 content = msg.get("content") 2319 2320 if role == "system": 2321 # Store system prompt separately, not in history for Anthropic 2322 self.system = content 2323 2324 elif role == "user": 2325 # User messages - check if it contains tool results 2326 if "metadata" in msg and "tool_info" in msg["metadata"]: 2327 tool_info = msg["metadata"]["tool_info"] 2328 # Check for result or directly use content 2329 result_content = tool_info.get("result", content) 2330 2331 # Create Anthropic tool result format 2332 try: 2333 result_str = json.dumps(result_content) if isinstance(result_content, (dict, list)) else str(result_content) 2334 except: 2335 result_str = str(result_content) 2336 2337 tool_result_msg = { 2338 "role": "user", 2339 "content": [{ 2340 "type": "tool_result", 2341 "tool_use_id": tool_info["id"], 2342 "content": result_str 2343 }] 2344 } 2345 self.history.append(tool_result_msg) 2346 else: 2347 # Regular user message 2348 self.history.append({ 2349 "role": "user", 2350 "content": content 2351 }) 2352 2353 elif role == "assistant": 2354 # For assistant messages, check for tool use 2355 if "metadata" in msg and "tool_info" in msg["metadata"]: 2356 # This is an assistant message with tool use 2357 tool_info = msg["metadata"]["tool_info"] 2358 2359 # Build content blocks 2360 content_blocks = [] 2361 2362 # Add text content if present 2363 if content: 2364 content_blocks.append({ 2365 "type": "text", 2366 "text": content if isinstance(content, str) else "" 2367 }) 2368 2369 # Add tool use block - safely convert arguments 2370 try: 2371 # Parse arguments to ensure it's a dictionary 2372 if isinstance(tool_info["arguments"], str): 2373 try: 2374 args_dict = json.loads(tool_info["arguments"]) 2375 except json.JSONDecodeError: 2376 args_dict = {"text": tool_info["arguments"]} 2377 else: 2378 args_dict = tool_info["arguments"] 2379 except: 2380 args_dict = {"error": "Failed to parse arguments"} 2381 2382 content_blocks.append({ 2383 "type": "tool_use", 2384 "id": tool_info["id"], 2385 "name": tool_info["name"], 2386 "input": args_dict 2387 }) 2388 2389 # Create Anthropic assistant message with tool use 2390 self.history.append({ 2391 "role": "assistant", 2392 "content": content_blocks 2393 }) 2394 else: 2395 # Regular assistant message 2396 self.history.append({ 2397 "role": "assistant", 2398 "content": content 2399 }) 2400 2401 elif role == "tool": 2402 # Tool messages in standard format get converted to user messages with tool_result 2403 if "metadata" in msg and "tool_info" in msg["metadata"]: 2404 tool_info = msg["metadata"]["tool_info"] 2405 2406 try: 2407 result_str = json.dumps(content) if isinstance(content, (dict, list)) else str(content) 2408 except: 2409 result_str = str(content) 2410 2411 # Create Anthropic tool result message 2412 tool_result_msg = { 2413 "role": "user", 2414 "content": [{ 2415 "type": "tool_result", 2416 "tool_use_id": tool_info["id"], 2417 "content": result_str 2418 }] 2419 } 2420 self.history.append(tool_result_msg) 2421 2422 def _generic_normalize_message(self, msg): 2423 """Generic normalizer for unknown SDKs.""" 2424 role = msg.get("role") 2425 content = msg.get("content") 2426 2427 if role in ["user", "assistant", "system"]: 2428 self.history.append({ 2429 "role": role, 2430 "content": content 2431 }) 2432 2433 def track_token_usage(self): 2434 """Track and return token usage across the conversation history. 2435 2436 Returns: 2437 dict: Dictionary containing current token counts, limits, and history. 2438 """ 2439 if not hasattr(self, "_token_history"): 2440 self._token_history = [] 2441 2442 # Count current tokens 2443 current_count = self._count_tokens(self.history) 2444 2445 # Add to history 2446 timestamp = datetime.now(timezone.utc).isoformat() 2447 self._token_history.append({ 2448 "timestamp": timestamp, 2449 "count": current_count, 2450 "limit": self.context_length, 2451 "provider": self.provider, 2452 "model": self.model 2453 }) 2454 2455 # Keep only the last 100 measurements to avoid unlimited growth 2456 if len(self._token_history) > 100: 2457 self._token_history = self._token_history[-100:] 2458 2459 # Return current tracking info 2460 return { 2461 "current": current_count, 2462 "limit": self.context_length, 2463 "percentage": round((current_count / self.context_length) * 100, 1) if self.context_length else 0, 2464 "history": self._token_history[-10:], # Return last 10 measurements 2465 "provider": self.provider, 2466 "model": self.model 2467 } 2468 2469 def get_message_token_breakdown(self): 2470 """Analyze token usage by message type and provide a detailed breakdown. 2471 2472 Returns: 2473 dict: Token usage broken down by message types and roles. 2474 """ 2475 breakdown = { 2476 "total": 0, 2477 "by_role": { 2478 "system": 0, 2479 "user": 0, 2480 "assistant": 0, 2481 "tool": 0 2482 }, 2483 "by_type": { 2484 "text": 0, 2485 "tool_calls": 0, 2486 "tool_results": 0 2487 }, 2488 "messages": [] 2489 } 2490 2491 # Analyze each message 2492 for i, msg in enumerate(self.history): 2493 msg_tokens = self._count_tokens([msg]) 2494 role = msg.get("role", "unknown") 2495 2496 # Track by role 2497 if role in breakdown["by_role"]: 2498 breakdown["by_role"][role] += msg_tokens 2499 2500 # Track by content type 2501 if role == "assistant" and msg.get("tool_calls"): 2502 breakdown["by_type"]["tool_calls"] += msg_tokens 2503 elif role == "tool": 2504 breakdown["by_type"]["tool_results"] += msg_tokens 2505 else: 2506 breakdown["by_type"]["text"] += msg_tokens 2507 2508 # Add individual message data 2509 breakdown["messages"].append({ 2510 "index": i, 2511 "role": role, 2512 "tokens": msg_tokens, 2513 "has_tools": bool(msg.get("tool_calls") or msg.get("tool_use") or 2514 (isinstance(msg.get("content"), list) and 2515 any(isinstance(c, dict) and c.get("type") in ["tool_use", "tool_result"] 2516 for c in msg.get("content", [])))) 2517 }) 2518 2519 # Update tota? 2520 breakdown["total"] += msg_tokens 2521 2522 return breakdown
62 def __init__( 63 self, 64 base_url: Optional[str] = None, 65 api_key: Optional[str] = None, 66 model: str = "openai:gpt-4o-mini", 67 fallback_model = "ollama:mistral-nemo:latest", 68 tools: Optional[bool] = True, 69 stream: bool = True, 70 context_length: int = 128000, 71 max_retries: int = 3, 72 retry_delay: float = 1.0, 73 log_path: Optional[str] = None, 74 session_enabled: bool = False, 75 session_id: Optional[str] = None, 76 session_path: Optional[str] = None 77 ): 78 """Initialize the universal AI interaction client. 79 80 Args: 81 base_url: Optional base URL for the API. If None, uses the provider's default URL. 82 api_key: Optional API key. If None, attempts to use environment variables based on provider. 83 model: Model identifier in format "provider:model_name" (e.g., "openai:gpt-4o-mini"). 84 tools: Enable (True) or disable (False) tool calling; None for auto-detection based on model support. 85 stream: Enable (True) or disable (False) streaming responses. 86 context_length: Maximum number of tokens to maintain in conversation history. 87 max_retries: Maximum number of retries for failed API calls. 88 retry_delay: Initial delay (in seconds) for exponential backoff retries. 89 session_enabled: Enable persistent session support. 90 session_id: Optional session ID to load messages from. 91 92 Raises: 93 ValueError: If provider is not supported or API key is missing for non-Ollama providers. 94 """ 95 self.system = "You are a helpful Assistant." 96 self.logger = logging.getLogger(f"InteractorLogger_{id(self)}") 97 self.logger.setLevel(logging.DEBUG) 98 self.providers = { 99 "openai": { 100 "sdk": "openai", 101 "base_url": "https://api.openai.com/v1", 102 "api_key": api_key or os.getenv("OPENAI_API_KEY") or None 103 }, 104 "ollama": { 105 "sdk": "openai", 106 "base_url": "http://localhost:11434/v1", 107 "api_key": api_key or "ollama" 108 }, 109 "nvidia": { 110 "sdk": "openai", 111 "base_url": "https://integrate.api.nvidia.com/v1", 112 "api_key": api_key or os.getenv("NVIDIA_API_KEY") or None 113 }, 114 "google": { 115 "sdk": "openai", 116 "base_url": "https://generativelanguage.googleapis.com/v1beta/openai", 117 "api_key": api_key or os.getenv("GEMINI_API_KEY") or None 118 }, 119 "anthropic": { 120 "sdk": "anthropic", 121 "base_url": "https://api.anthropic.com/v1", 122 "api_key": api_key or os.getenv("ANTHROPIC_API_KEY") or None 123 }, 124 "mistral": { 125 "sdk": "openai", 126 "base_url": "https://api.mistral.ai/v1", 127 "api_key": api_key or os.getenv("MISTRAL_API_KEY") or None 128 }, 129 "deepseek": { 130 "sdk": "openai", 131 "base_url": "https://api.deepseek.com", 132 "api_key": api_key or os.getenv("DEEPSEEK_API_KEY") or None 133 }, 134 "grok": { 135 "sdk": "grok", 136 "base_url": "https://api.x.ai/v1", 137 "api_key": api_key or os.getenv("GROK_API_KEY") or None 138 } 139 } 140 141 142 # Console log handler (always enabled at WARNING+) 143 if not self.logger.handlers: 144 console_handler = logging.StreamHandler(sys.stdout) 145 console_handler.setLevel(logging.WARNING) 146 console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) 147 self.logger.addHandler(console_handler) 148 149 self._log_enabled = False 150 if log_path: 151 file_handler = logging.FileHandler(log_path) 152 file_handler.setLevel(logging.DEBUG) 153 file_handler.setFormatter(logging.Formatter( 154 "%(asctime)s - %(levelname)s - %(message)s", 155 datefmt="%Y-%m-%d %H:%M:%S" 156 )) 157 self.logger.addHandler(file_handler) 158 self._log_enabled = True 159 160 161 self.token_estimate = 0 162 self.last_token_estimate = 0 163 self.stream = stream 164 self.tools = [] 165 self.session_history = [] 166 self.history = [] 167 self.context_length = context_length 168 self.encoding = None 169 self.max_retries = max_retries 170 self.retry_delay = retry_delay 171 self.reveal_tool = [] 172 self.fallback_model = fallback_model 173 self.sdk = None 174 175 # Session support 176 self.session_enabled = session_enabled 177 self.session_id = session_id 178 self._last_session_id = session_id 179 self.session = Session(directory=session_path) if session_enabled else None 180 181 182 if model is None: 183 model = "openai:gpt-4o-mini" 184 185 # Initialize model + encoding 186 self._setup_client(model, base_url, api_key) 187 self.tools_enabled = self.tools_supported if tools is None else tools and self.tools_supported 188 self._setup_encoding() 189 self.messages_add(role="system", content=self.system)
Initialize the universal AI interaction client.
Args: base_url: Optional base URL for the API. If None, uses the provider's default URL. api_key: Optional API key. If None, attempts to use environment variables based on provider. model: Model identifier in format "provider:model_name" (e.g., "openai:gpt-4o-mini"). tools: Enable (True) or disable (False) tool calling; None for auto-detection based on model support. stream: Enable (True) or disable (False) streaming responses. context_length: Maximum number of tokens to maintain in conversation history. max_retries: Maximum number of retries for failed API calls. retry_delay: Initial delay (in seconds) for exponential backoff retries. session_enabled: Enable persistent session support. session_id: Optional session ID to load messages from.
Raises: ValueError: If provider is not supported or API key is missing for non-Ollama providers.
347 def add_function( 348 self, 349 external_callable: Callable, 350 name: Optional[str] = None, 351 description: Optional[str] = None, 352 override: bool = False, 353 disabled: bool = False, 354 schema_extensions: Optional[Dict[str, Any]] = None 355 ): 356 """ 357 Register a function for LLM tool calling with full type hints and metadata. 358 359 Args: 360 external_callable (Callable): The function to register. 361 name (Optional[str]): Optional custom name. Defaults to function's __name__. 362 description (Optional[str]): Optional custom description. Defaults to first line of docstring. 363 override (bool): If True, replaces an existing tool with the same name. 364 disabled (bool): If True, registers the function in a disabled state. 365 schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to 366 schema extensions that override or add to the auto-generated schema. 367 368 Raises: 369 ValueError: If the callable is invalid or duplicate name found without override. 370 371 Example: 372 interactor.add_function( 373 my_tool, 374 override=True, 375 disabled=False, 376 schema_extensions={ 377 "param1": {"minimum": 0, "maximum": 100}, 378 "param2": {"format": "email"} 379 } 380 ) 381 """ 382 def _python_type_to_schema(ptype: Any) -> dict: 383 """Convert a Python type annotation to OpenAI-compatible JSON Schema.""" 384 # Handle None case 385 if ptype is None: 386 return {"type": "null"} 387 388 # Get the origin and arguments of the type 389 origin = get_origin(ptype) 390 args = get_args(ptype) 391 392 # Handle Union types (including Optional) 393 if origin is Union: 394 # Check for Optional (Union with None) 395 none_type = type(None) 396 if none_type in args: 397 non_none = [a for a in args if a is not none_type] 398 if len(non_none) == 1: 399 inner = _python_type_to_schema(non_none[0]) 400 inner_copy = inner.copy() 401 inner_copy["nullable"] = True 402 return inner_copy 403 # Multiple types excluding None 404 types = [_python_type_to_schema(a) for a in non_none] 405 return {"anyOf": types, "nullable": True} 406 # Regular Union without None 407 return {"anyOf": [_python_type_to_schema(a) for a in args]} 408 409 # Handle List and similar container types 410 if origin in (list, List): 411 item_type = args[0] if args else Any 412 if item_type is Any: 413 return {"type": "array"} 414 return {"type": "array", "items": _python_type_to_schema(item_type)} 415 416 # Handle Dict types with typing info 417 if origin in (dict, Dict): 418 if not args or len(args) != 2: 419 return {"type": "object"} 420 421 key_type, val_type = args 422 # We can only really use val_type in JSON Schema 423 if val_type is not Any and val_type is not object: 424 return { 425 "type": "object", 426 "additionalProperties": _python_type_to_schema(val_type) 427 } 428 return {"type": "object"} 429 430 # Handle Literal types for enums 431 if origin is Literal: 432 values = args 433 # Try to determine type from values 434 if all(isinstance(v, str) for v in values): 435 return {"type": "string", "enum": list(values)} 436 elif all(isinstance(v, bool) for v in values): 437 return {"type": "boolean", "enum": list(values)} 438 elif all(isinstance(v, (int, float)) for v in values): 439 return {"type": "number", "enum": list(values)} 440 else: 441 # Mixed types, use anyOf 442 return {"anyOf": [{"type": _get_json_type(v), "enum": [v]} for v in values]} 443 444 # Handle basic types 445 if ptype is str: 446 return {"type": "string"} 447 if ptype is int: 448 return {"type": "integer"} 449 if ptype is float: 450 return {"type": "number"} 451 if ptype is bool: 452 return {"type": "boolean"} 453 454 # Handle common datetime types 455 if ptype is datetime: 456 return {"type": "string", "format": "date-time"} 457 if ptype is date: 458 return {"type": "string", "format": "date"} 459 460 # Handle UUID 461 if ptype is uuid.UUID: 462 return {"type": "string", "format": "uuid"} 463 464 # Default to object for any other types 465 return {"type": "object"} 466 467 def _get_json_type(value): 468 """Get the JSON Schema type name for a Python value.""" 469 if isinstance(value, str): 470 return "string" 471 elif isinstance(value, bool): 472 return "boolean" 473 elif isinstance(value, int) or isinstance(value, float): 474 return "number" 475 elif isinstance(value, list): 476 return "array" 477 elif isinstance(value, dict): 478 return "object" 479 else: 480 return "object" # Default 481 482 def _parse_param_docs(docstring: str) -> dict: 483 """Extract parameter descriptions from a docstring.""" 484 if not docstring: 485 return {} 486 487 lines = docstring.splitlines() 488 param_docs = {} 489 current_param = None 490 in_params = False 491 492 # Regular expressions for finding parameter sections and param lines 493 param_section_re = re.compile(r"^(Args|Parameters):\s*$") 494 param_line_re = re.compile(r"^\s{4}(\w+)\s*(?:\([^\)]*\))?:\s*(.*)") 495 496 for line in lines: 497 # Check if we're entering the parameters section 498 if param_section_re.match(line.strip()): 499 in_params = True 500 continue 501 502 if in_params: 503 # Skip empty lines 504 if not line.strip(): 505 continue 506 507 # Check for a parameter definition line 508 match = param_line_re.match(line) 509 if match: 510 current_param = match.group(1) 511 param_docs[current_param] = match.group(2).strip() 512 # Check for continuation of a parameter description 513 elif current_param and line.startswith(" " * 8): 514 param_docs[current_param] += " " + line.strip() 515 # If we see a line that doesn't match our patterns, we're out of the params section 516 else: 517 current_param = None 518 519 return param_docs 520 521 # Start of main function logic 522 523 # Skip if tools are disabled 524 if not self.tools_enabled: 525 return 526 527 # Validate input callable 528 if not external_callable: 529 raise ValueError("A valid external callable must be provided.") 530 531 # Set function name, either from parameter or from callable's __name__ 532 function_name = name or external_callable.__name__ 533 534 # Try to get docstring and extract description 535 try: 536 docstring = inspect.getdoc(external_callable) 537 description = description or (docstring.split("\n")[0].strip() if docstring else "No description provided.") 538 except Exception as e: 539 self._log(f"[TOOL] Warning: Could not extract docstring from {function_name}: {e}", level="warning") 540 docstring = "" 541 description = description or "No description provided." 542 543 # Extract parameter documentation from docstring 544 param_docs = _parse_param_docs(docstring) 545 546 # Handle conflicts with existing functions 547 if override: 548 self.delete_function(function_name) 549 elif any(t["function"]["name"] == function_name for t in self.tools): 550 raise ValueError(f"Function '{function_name}' is already registered. Use override=True to replace.") 551 552 # Try to get function signature for parameter info 553 try: 554 signature = inspect.signature(external_callable) 555 except (ValueError, TypeError) as e: 556 raise ValueError(f"Cannot inspect callable '{function_name}': {e}") 557 558 # Process parameters to build schema 559 properties = {} 560 required = [] 561 562 for param_name, param in signature.parameters.items(): 563 # Skip self, cls parameters for instance/class methods 564 if param_name in ("self", "cls") and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: 565 continue 566 567 # Get parameter annotation, defaulting to Any 568 annotation = param.annotation if param.annotation != inspect.Parameter.empty else Any 569 570 try: 571 # Convert Python type to JSON Schema 572 schema = _python_type_to_schema(annotation) 573 574 # Add description from docstring or create a default one 575 schema["description"] = param_docs.get(param_name, f"{param_name} parameter") 576 577 # Add to properties 578 properties[param_name] = schema 579 580 # If no default, add to required list 581 if param.default == inspect.Parameter.empty: 582 required.append(param_name) 583 584 except Exception as e: 585 self._log(f"[TOOL] Error processing parameter {param_name} for {function_name}: {e}", level="error") 586 # Add a basic object schema as fallback 587 properties[param_name] = { 588 "type": "object", 589 "description": f"{param_name} parameter (type conversion failed)" 590 } 591 592 # Apply schema extensions if provided 593 if schema_extensions: 594 for param_name, extensions in schema_extensions.items(): 595 if param_name in properties: 596 properties[param_name].update(extensions) 597 598 # Build the final tool specification 599 tool_spec = { 600 "type": "function", 601 "function": { 602 "name": function_name, 603 "description": description, 604 "parameters": { 605 "type": "object", 606 "properties": properties, 607 "required": required 608 } 609 } 610 } 611 612 # Set disabled flag if requested 613 if disabled: 614 tool_spec["function"]["disabled"] = True 615 616 # Add to tools list 617 self.tools.append(tool_spec) 618 619 # Make the function available as an attribute on the instance 620 setattr(self, function_name, external_callable) 621 622 # Log the registration 623 self._log(f"[TOOL] Registered function '{function_name}' with {len(properties)} parameters", level="info") 624 625 return function_name # Return the name for reference
Register a function for LLM tool calling with full type hints and metadata.
Args: external_callable (Callable): The function to register. name (Optional[str]): Optional custom name. Defaults to function's __name__. description (Optional[str]): Optional custom description. Defaults to first line of docstring. override (bool): If True, replaces an existing tool with the same name. disabled (bool): If True, registers the function in a disabled state. schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to schema extensions that override or add to the auto-generated schema.
Raises: ValueError: If the callable is invalid or duplicate name found without override.
Example: interactor.add_function( my_tool, override=True, disabled=False, schema_extensions={ "param1": {"minimum": 0, "maximum": 100}, "param2": {"format": "email"} } )
628 def disable_function(self, name: str) -> bool: 629 """ 630 Disable a registered tool function by name. 631 632 This marks the function as inactive for tool calling without removing it from the internal registry. 633 The function remains visible in the tool listing but is skipped during tool selection by the LLM. 634 635 Args: 636 name (str): The name of the function to disable. 637 638 Returns: 639 bool: True if the function was found and disabled, False otherwise. 640 641 Example: 642 interactor.disable_function("extract_text") 643 """ 644 for tool in self.tools: 645 if tool["function"]["name"] == name: 646 tool["function"]["disabled"] = True 647 return True 648 return False
Disable a registered tool function by name.
This marks the function as inactive for tool calling without removing it from the internal registry. The function remains visible in the tool listing but is skipped during tool selection by the LLM.
Args: name (str): The name of the function to disable.
Returns: bool: True if the function was found and disabled, False otherwise.
Example: interactor.disable_function("extract_text")
651 def enable_function(self, name: str) -> bool: 652 """ 653 Re-enable a previously disabled tool function by name. 654 655 This removes the 'disabled' flag from a tool function, making it available again for LLM use. 656 657 Args: 658 name (str): The name of the function to enable. 659 660 Returns: 661 bool: True if the function was found and enabled, False otherwise. 662 663 Example: 664 interactor.enable_function("extract_text") 665 """ 666 for tool in self.tools: 667 if tool["function"]["name"] == name: 668 tool["function"].pop("disabled", None) 669 return True 670 return False
Re-enable a previously disabled tool function by name.
This removes the 'disabled' flag from a tool function, making it available again for LLM use.
Args: name (str): The name of the function to enable.
Returns: bool: True if the function was found and enabled, False otherwise.
Example: interactor.enable_function("extract_text")
673 def delete_function(self, name: str) -> bool: 674 """ 675 Permanently remove a registered tool function from the Interactor. 676 677 This deletes both the tool metadata and the callable attribute, making it fully inaccessible 678 from the active session. Useful for dynamically trimming the toolset. 679 680 Args: 681 name (str): The name of the function to delete. 682 683 Returns: 684 bool: True if the function was found and removed, False otherwise. 685 686 Example: 687 interactor.delete_function("extract_text") 688 """ 689 before = len(self.tools) 690 self.tools = [tool for tool in self.tools if tool["function"]["name"] != name] 691 if hasattr(self, name): 692 delattr(self, name) 693 return len(self.tools) < before
Permanently remove a registered tool function from the Interactor.
This deletes both the tool metadata and the callable attribute, making it fully inaccessible from the active session. Useful for dynamically trimming the toolset.
Args: name (str): The name of the function to delete.
Returns: bool: True if the function was found and removed, False otherwise.
Example: interactor.delete_function("extract_text")
696 def list_functions(self) -> List[Dict[str, Any]]: 697 """Get the list of registered functions for tool calling. 698 699 Returns: 700 List[Dict[str, Any]]: List of registered functions. 701 """ 702 return self.tools
Get the list of registered functions for tool calling.
Returns: List[Dict[str, Any]]: List of registered functions.
704 def list_models( 705 self, 706 providers: Optional[Union[str, List[str]]] = None, 707 filter: Optional[str] = None 708 ) -> List[str]: 709 """Retrieve available models from configured providers. 710 711 Args: 712 providers: Provider name or list of provider names. If None, all are queried. 713 filter: Optional regex to filter model names. 714 715 Returns: 716 List[str]: Sorted list of "provider:model_id" strings. 717 """ 718 models = [] 719 720 if providers is None: 721 providers_to_list = self.providers 722 elif isinstance(providers, str): 723 providers_to_list = {providers: self.providers.get(providers)} 724 elif isinstance(providers, list): 725 providers_to_list = {p: self.providers.get(p) for p in providers} 726 else: 727 return [] 728 729 invalid_providers = [p for p in providers_to_list if p not in self.providers or self.providers[p] is None] 730 if invalid_providers: 731 self.logger.error(f"Invalid providers: {invalid_providers}") 732 return [] 733 734 regex_pattern = None 735 if filter: 736 try: 737 regex_pattern = re.compile(filter, re.IGNORECASE) 738 except re.error as e: 739 self.logger.error(f"Invalid regex pattern: {e}") 740 return [] 741 742 for provider_name, config in providers_to_list.items(): 743 sdk = config.get("sdk", "openai") 744 base_url = config.get("base_url") 745 api_key = config.get("api_key") 746 747 try: 748 if sdk == "openai": 749 client = openai.OpenAI(api_key=api_key, base_url=base_url) 750 response = client.models.list() 751 for model in response.data: 752 model_id = f"{provider_name}:{model.id}" 753 if not regex_pattern or regex_pattern.search(model_id): 754 models.append(model_id) 755 756 elif sdk == "anthropic": 757 client = Anthropic(api_key=api_key) 758 response = client.models.list() 759 for model in response: 760 model_id = f"{provider_name}:{model.id}" 761 if not regex_pattern or regex_pattern.search(model_id): 762 models.append(model_id) 763 else: 764 self.logger.warning(f"SDK '{sdk}' for provider '{provider_name}' is not supported by list_models()") 765 766 except Exception as e: 767 self.logger.error(f"Failed to list models for {provider_name}: {e}") 768 769 return sorted(models, key=str.lower)
Retrieve available models from configured providers.
Args: providers: Provider name or list of provider names. If None, all are queried. filter: Optional regex to filter model names.
Returns: List[str]: Sorted list of "provider:model_id" strings.
803 def interact( 804 self, 805 user_input: Optional[str], 806 quiet: bool = False, 807 tools: bool = True, 808 stream: bool = True, 809 markdown: bool = False, 810 model: Optional[str] = None, 811 output_callback: Optional[Callable[[str], None]] = None, 812 session_id: Optional[str] = None 813 ) -> Optional[str]: 814 """Main universal gateway for all LLM interaction.""" 815 if not user_input: 816 return None 817 818 # Setup model if specified 819 if model: 820 self._setup_client(model) 821 self._setup_encoding() 822 823 # Session handling 824 if self.session_enabled and session_id: 825 self.session_id = session_id 826 self.session_load(session_id) 827 828 # Add user message using messages_add 829 self.messages_add(role="user", content=user_input) 830 831 # Log token count estimate 832 token_count = self._count_tokens(self.history) 833 if not quiet: 834 print(f"[dim]Estimated tokens in context: {token_count} / {self.context_length}[/dim]") 835 836 # Make sure we have enough context space 837 if token_count > self.context_length: 838 if self._cycle_messages(): 839 if not quiet: 840 print("[red]Context window exceeded. Cannot proceed.[/red]") 841 return None 842 843 # Log user input 844 self._log(f"[USER] {user_input}") 845 846 # Handle the actual interaction with complete streaming for all responses 847 result = asyncio.run(self._interact_async_core( 848 user_input=user_input, 849 quiet=quiet, 850 tools=tools, 851 stream=stream, 852 markdown=markdown, 853 output_callback=output_callback 854 )) 855 856 # Log completion for this interaction 857 self._log(f"[INTERACTION] Completed with {len(self.history)} total messages") 858 859 return result
Main universal gateway for all LLM interaction.
1794 def messages_add( 1795 self, 1796 role: str, 1797 content: Any, 1798 tool_info: Optional[Dict] = None, 1799 normalize: bool = True 1800 ) -> str: 1801 """ 1802 Add a message to the standardized session_history and then update SDK-specific history. 1803 1804 This method is the central point for all message additions to the conversation. 1805 1806 Args: 1807 role: The role of the message ("user", "assistant", "system", "tool") 1808 content: The message content (text or structured) 1809 tool_info: Optional tool-related metadata 1810 normalize: Whether to normalize history after adding this message 1811 1812 Returns: 1813 str: Unique ID of the added message 1814 """ 1815 # Generate a unique message ID 1816 message_id = str(uuid.uuid4()) 1817 1818 # Create the standardized message for session_history 1819 timestamp = datetime.now(timezone.utc).isoformat() 1820 1821 # Store system messages directly 1822 if role == "system": 1823 self.system = content 1824 1825 # Create standard format message 1826 standard_message = { 1827 "role": role, 1828 "content": content, 1829 "id": message_id, 1830 "timestamp": timestamp, 1831 "metadata": { 1832 "sdk": self.sdk 1833 } 1834 } 1835 1836 # Add tool info if provided 1837 if tool_info: 1838 standard_message["metadata"]["tool_info"] = tool_info 1839 1840 # Add to session_history 1841 if not hasattr(self, "session_history"): 1842 self.session_history = [] 1843 1844 self.session_history.append(standard_message) 1845 1846 # Save to persistent session if enabled 1847 if self.session_enabled and self.session_id: 1848 # Convert standard message to session-compatible format 1849 session_msg = { 1850 "role": role, 1851 "content": content, 1852 "id": message_id, 1853 "timestamp": timestamp 1854 } 1855 1856 # Add tool-related fields if present 1857 if tool_info: 1858 for key, value in tool_info.items(): 1859 session_msg[key] = value 1860 1861 # Store in session 1862 self.session.msg_insert(self.session_id, session_msg) 1863 1864 # Update the SDK-specific format in self.history by running the normalizer 1865 if normalize: 1866 # We only need to normalize the most recent message for efficiency 1867 # Pass a flag indicating we're just normalizing a new message 1868 self._normalizer(force=False, new_message_only=True) 1869 1870 # Log the added message 1871 self._log(f"[MESSAGE ADDED] {role}: {str(content)[:50]}...") 1872 1873 return message_id
Add a message to the standardized session_history and then update SDK-specific history.
This method is the central point for all message additions to the conversation.
Args: role: The role of the message ("user", "assistant", "system", "tool") content: The message content (text or structured) tool_info: Optional tool-related metadata normalize: Whether to normalize history after adding this message
Returns: str: Unique ID of the added message
1875 def messages_system(self, prompt: str): 1876 """Set or retrieve the current system prompt.""" 1877 if not isinstance(prompt, str) or not prompt: 1878 return self.system 1879 1880 # If the prompt hasn't changed, don't do anything 1881 if self.system == prompt: 1882 return self.system 1883 1884 # Update the system prompt 1885 old_system = self.system 1886 self.system = prompt 1887 1888 # For OpenAI, update or insert the system message in history 1889 if self.sdk == "openai": 1890 # Check if there's already a system message 1891 system_index = next((i for i, msg in enumerate(self.history) 1892 if msg.get("role") == "system"), None) 1893 1894 if system_index is not None: 1895 # Update existing system message 1896 self.history[system_index]["content"] = prompt 1897 else: 1898 # Insert new system message at the beginning 1899 self.history.insert(0, {"role": "system", "content": prompt}) 1900 1901 # For Anthropic, system message is not part of history, just save it for API calls 1902 1903 # Log to session only if prompt actually changed 1904 if self.session_enabled and self.session_id and old_system != prompt: 1905 self.session.msg_insert(self.session_id, {"role": "system", "content": prompt}) 1906 1907 return self.system
Set or retrieve the current system prompt.
1909 def messages(self) -> list: 1910 """Return full session messages (persisted or in-memory).""" 1911 if self.session_enabled and self.session_id: 1912 return self.session.load_full(self.session_id).get("messages", []) 1913 return self.session_history
Return full session messages (persisted or in-memory).
1915 def messages_length(self) -> int: 1916 """Calculate the total token count for the message history.""" 1917 if not self.encoding: 1918 return 0 1919 1920 total_tokens = 0 1921 for message in self.history: 1922 if message.get("content"): 1923 total_tokens += len(self.encoding.encode(message["content"])) 1924 if message.get("tool_calls"): 1925 for tool_call in message["tool_calls"]: 1926 if tool_call.get("function"): 1927 total_tokens += len(self.encoding.encode(tool_call["function"].get("name", ""))) 1928 total_tokens += len(self.encoding.encode(tool_call["function"].get("arguments", ""))) 1929 return total_tokens
Calculate the total token count for the message history.
1931 def session_load(self, session_id: Optional[str]): 1932 """Load and normalize messages for a specific session.""" 1933 self.session_id = session_id 1934 self._last_session_id = session_id 1935 1936 if self.session_enabled and session_id: 1937 try: 1938 # Load raw session data 1939 session_data = self.session.load_full(session_id) 1940 messages = session_data.get("messages", []) 1941 1942 # Convert session format to our standard format 1943 self.session_history = [] 1944 1945 # Track the most recent system message 1946 latest_system_msg = None 1947 1948 for msg in messages: 1949 # Extract fields 1950 role = msg.get("role", "user") 1951 content = msg.get("content", "") 1952 msg_id = msg.get("id", str(uuid.uuid4())) 1953 timestamp = msg.get("timestamp", datetime.now(timezone.utc).isoformat()) 1954 1955 # If this is a system message, track it but don't add to session_history yet 1956 if role == "system": 1957 if latest_system_msg is None or timestamp > latest_system_msg["timestamp"]: 1958 latest_system_msg = { 1959 "role": role, 1960 "content": content, 1961 "id": msg_id, 1962 "timestamp": timestamp, 1963 "metadata": {"sdk": self.sdk} 1964 } 1965 continue 1966 1967 # Build tool_info if present 1968 tool_info = None 1969 if any(key in msg for key in ["tool_use_id", "tool_call_id", "name", "arguments"]): 1970 tool_info = { 1971 "id": msg.get("tool_use_id") or msg.get("tool_call_id"), 1972 "name": msg.get("name", "unknown_tool"), 1973 "arguments": msg.get("arguments", {}) 1974 } 1975 1976 # Create standard message 1977 standard_msg = { 1978 "role": role, 1979 "content": content, 1980 "id": msg_id, 1981 "timestamp": timestamp, 1982 "metadata": { 1983 "sdk": self.sdk 1984 } 1985 } 1986 1987 if tool_info: 1988 standard_msg["metadata"]["tool_info"] = tool_info 1989 1990 self.session_history.append(standard_msg) 1991 1992 # If we found a system message, update the system property and add to history 1993 if latest_system_msg: 1994 self.system = latest_system_msg["content"] 1995 # Insert at the beginning of session_history 1996 self.session_history.insert(0, latest_system_msg) 1997 else: 1998 # If no system message was found, add the current system message 1999 self.messages_add(role="system", content=self.system) 2000 2001 # Normalize to current SDK format 2002 self._normalizer(force=True) 2003 2004 self._log(f"[SESSION] Switched to session '{session_id}'") 2005 except Exception as e: 2006 self.logger.error(f"Failed to load session '{session_id}': {e}") 2007 self.session_reset() 2008 else: 2009 # Reset to empty state with system message 2010 self.session_reset()
Load and normalize messages for a specific session.
2012 def session_reset(self): 2013 """ 2014 Reset the current session state and reinitialize to default system prompt. 2015 2016 Clears history, disables session ID tracking, and returns to in-memory mode. 2017 """ 2018 self.session_id = None 2019 self._last_session_id = None 2020 2021 # Clear histories 2022 self.session_history = [] 2023 self.history = [] 2024 2025 # Reapply the system message 2026 if hasattr(self, "system") and self.system: 2027 # Add to session_history 2028 self.messages_add(role="system", content=self.system) 2029 else: 2030 # Ensure we have a default system message 2031 self.system = "You are a helpful Assistant." 2032 self.messages_add(role="system", content=self.system) 2033 2034 self._log("[SESSION] Reset to in-memory mode")
Reset the current session state and reinitialize to default system prompt.
Clears history, disables session ID tracking, and returns to in-memory mode.
2433 def track_token_usage(self): 2434 """Track and return token usage across the conversation history. 2435 2436 Returns: 2437 dict: Dictionary containing current token counts, limits, and history. 2438 """ 2439 if not hasattr(self, "_token_history"): 2440 self._token_history = [] 2441 2442 # Count current tokens 2443 current_count = self._count_tokens(self.history) 2444 2445 # Add to history 2446 timestamp = datetime.now(timezone.utc).isoformat() 2447 self._token_history.append({ 2448 "timestamp": timestamp, 2449 "count": current_count, 2450 "limit": self.context_length, 2451 "provider": self.provider, 2452 "model": self.model 2453 }) 2454 2455 # Keep only the last 100 measurements to avoid unlimited growth 2456 if len(self._token_history) > 100: 2457 self._token_history = self._token_history[-100:] 2458 2459 # Return current tracking info 2460 return { 2461 "current": current_count, 2462 "limit": self.context_length, 2463 "percentage": round((current_count / self.context_length) * 100, 1) if self.context_length else 0, 2464 "history": self._token_history[-10:], # Return last 10 measurements 2465 "provider": self.provider, 2466 "model": self.model 2467 }
Track and return token usage across the conversation history.
Returns: dict: Dictionary containing current token counts, limits, and history.
2469 def get_message_token_breakdown(self): 2470 """Analyze token usage by message type and provide a detailed breakdown. 2471 2472 Returns: 2473 dict: Token usage broken down by message types and roles. 2474 """ 2475 breakdown = { 2476 "total": 0, 2477 "by_role": { 2478 "system": 0, 2479 "user": 0, 2480 "assistant": 0, 2481 "tool": 0 2482 }, 2483 "by_type": { 2484 "text": 0, 2485 "tool_calls": 0, 2486 "tool_results": 0 2487 }, 2488 "messages": [] 2489 } 2490 2491 # Analyze each message 2492 for i, msg in enumerate(self.history): 2493 msg_tokens = self._count_tokens([msg]) 2494 role = msg.get("role", "unknown") 2495 2496 # Track by role 2497 if role in breakdown["by_role"]: 2498 breakdown["by_role"][role] += msg_tokens 2499 2500 # Track by content type 2501 if role == "assistant" and msg.get("tool_calls"): 2502 breakdown["by_type"]["tool_calls"] += msg_tokens 2503 elif role == "tool": 2504 breakdown["by_type"]["tool_results"] += msg_tokens 2505 else: 2506 breakdown["by_type"]["text"] += msg_tokens 2507 2508 # Add individual message data 2509 breakdown["messages"].append({ 2510 "index": i, 2511 "role": role, 2512 "tokens": msg_tokens, 2513 "has_tools": bool(msg.get("tool_calls") or msg.get("tool_use") or 2514 (isinstance(msg.get("content"), list) and 2515 any(isinstance(c, dict) and c.get("type") in ["tool_use", "tool_result"] 2516 for c in msg.get("content", [])))) 2517 }) 2518 2519 # Update tota? 2520 breakdown["total"] += msg_tokens 2521 2522 return breakdown
Analyze token usage by message type and provide a detailed breakdown.
Returns: dict: Token usage broken down by message types and roles.
19class Session: 20 def __init__(self, directory: str = None): 21 """ 22 Initialize the session manager and ensure the session directory exists. 23 24 Args: 25 directory (str): Filesystem path for session storage. Must not be None or empty. 26 27 Raises: 28 ValueError: If directory is None or not a string. 29 OSError: If the directory cannot be created or accessed. 30 """ 31 if not directory: 32 raise ValueError("Session directory must be a valid non-empty string path.") 33 34 try: 35 self.path = Path(os.path.expanduser(directory)) 36 self.path.mkdir(parents=True, exist_ok=True) 37 except Exception as e: 38 raise OSError(f"Failed to initialize session directory '{directory}': {e}") 39 40 # --------------------------- 41 # Core CRUD 42 # --------------------------- 43 44 def list(self) -> List[Dict]: 45 """ 46 Return metadata for all sessions in the directory. 47 48 Returns: 49 List[Dict]: Sorted list of session metadata dictionaries. 50 """ 51 out = [] 52 for file in self.path.glob("*.json"): 53 try: 54 with open(file, "r") as f: 55 d = json.load(f) 56 out.append({ 57 "id": d.get("id"), 58 "name": d.get("name"), 59 "created": d.get("created"), 60 "tags": d.get("tags", []), 61 "summary": d.get("summary") 62 }) 63 except Exception: 64 continue 65 return sorted(out, key=lambda x: x["created"], reverse=True) 66 67 def create(self, name: str, tags: Optional[List[str]] = None) -> str: 68 """ 69 Create and persist a new session. 70 71 Args: 72 name (str): Name of the new session. 73 tags (List[str], optional): Optional list of tags. 74 75 Returns: 76 str: Unique session ID of the new session. 77 """ 78 sid = str(uuid.uuid4()) 79 session = { 80 "id": sid, 81 "name": name, 82 "created": datetime.now(timezone.utc).isoformat(), 83 "parent": None, 84 "branch_point": None, 85 "tags": tags or [], 86 "summary": None, 87 "messages": [] 88 } 89 self._save_file(sid, session) 90 return sid 91 92 def load(self, session_id: str) -> List[Dict]: 93 """ 94 Return OpenAI-compatible message list from a session. 95 96 Filters out internal keys and leaves only standard API-compatible fields. 97 98 Args: 99 session_id (str): ID of the session to load. 100 101 Returns: 102 List[Dict]: List of clean message dictionaries. 103 """ 104 session = self._read_file(session_id) 105 return [ 106 {k: v for k, v in m.items() if k in { 107 "role", "content", "tool_calls", "name", "function_call", "tool_call_id" 108 }} for m in session.get("messages", []) 109 ] 110 111 def load_full(self, session_id: str) -> Dict: 112 """ 113 Return the complete session file as-is. 114 115 Args: 116 session_id (str): ID of the session. 117 118 Returns: 119 Dict: Entire raw session data. 120 """ 121 return self._read_file(session_id) 122 123 def delete(self, session_id: str): 124 """ 125 Delete a session file from disk. 126 127 Args: 128 session_id (str): ID of the session to delete. 129 """ 130 file = self.path / f"{session_id}.json" 131 if file.exists(): 132 file.unlink() 133 134 def update(self, session_id: str, key: str, value: Any): 135 """ 136 Update a top-level key in a session file. 137 138 Args: 139 session_id (str): Session ID. 140 key (str): Field to update. 141 value (Any): New value for the field. 142 """ 143 session = self._read_file(session_id) 144 session[key] = value 145 self._save_file(session_id, session) 146 147 # --------------------------- 148 # Message Operations 149 # --------------------------- 150 151 def msg_insert(self, session_id: str, message: Dict) -> str: 152 """ 153 Insert a new message into a session. 154 155 Args: 156 session_id (str): Session ID. 157 message (Dict): Message dictionary to insert. 158 159 Returns: 160 str: ID of the inserted message. 161 """ 162 session = self._read_file(session_id) 163 entry = { 164 "id": str(uuid.uuid4()), 165 "timestamp": datetime.now(timezone.utc).isoformat(), 166 **message 167 } 168 session["messages"].append(entry) 169 self._save_file(session_id, session) 170 return entry["id"] 171 172 def msg_get(self, session_id: str, message_id: str) -> Optional[Dict]: 173 """ 174 Retrieve a specific message from a session. 175 176 Args: 177 session_id (str): Session ID. 178 message_id (str): ID of the message to retrieve. 179 180 Returns: 181 Optional[Dict]: The message if found, else None. 182 """ 183 session = self._read_file(session_id) 184 for msg in session.get("messages", []): 185 if msg.get("id") == message_id: 186 return msg 187 return None 188 189 def msg_index(self, session_id: str, message_id: str) -> Optional[int]: 190 """ 191 Get the index of a message within a session. 192 193 Args: 194 session_id (str): Session ID. 195 message_id (str): Message ID. 196 197 Returns: 198 Optional[int]: Index if found, else None. 199 """ 200 session = self._read_file(session_id) 201 for i, msg in enumerate(session.get("messages", [])): 202 if msg.get("id") == message_id: 203 return i 204 return None 205 206 def msg_update(self, session_id: str, message_id: str, new_content: str) -> bool: 207 """ 208 Update the content of a specific message. 209 210 Args: 211 session_id (str): Session ID. 212 message_id (str): Message ID. 213 new_content (str): New content for the message. 214 215 Returns: 216 bool: True if update succeeded, False otherwise. 217 """ 218 session = self._read_file(session_id) 219 for m in session["messages"]: 220 if m.get("id") == message_id: 221 m["content"] = new_content 222 self._save_file(session_id, session) 223 return True 224 return False 225 226 def msg_delete(self, session_id: str, message_id: str) -> bool: 227 """ 228 Delete a message from a session. 229 230 Args: 231 session_id (str): Session ID. 232 message_id (str): Message ID. 233 234 Returns: 235 bool: True if deletion occurred, False otherwise. 236 """ 237 session = self._read_file(session_id) 238 before = len(session["messages"]) 239 session["messages"] = [m for m in session["messages"] if m.get("id") != message_id] 240 self._save_file(session_id, session) 241 return len(session["messages"]) < before 242 243 # --------------------------- 244 # Branching & Summarization 245 # --------------------------- 246 247 def branch(self, from_id: str, message_id: str, new_name: str) -> str: 248 """ 249 Create a new session by branching from a specific message. 250 251 Args: 252 from_id (str): ID of the base session. 253 message_id (str): Message ID to branch from. 254 new_name (str): Name of the new session. 255 256 Returns: 257 str: ID of the new session. 258 """ 259 base = self._read_file(from_id) 260 index = next((i for i, m in enumerate(base["messages"]) if m.get("id") == message_id), None) 261 if index is None: 262 raise ValueError(f"Message ID {message_id} not found in session {from_id}.") 263 partial = base["messages"][:index + 1] 264 new_id = str(uuid.uuid4()) 265 branch = { 266 "id": new_id, 267 "name": new_name, 268 "created": datetime.now(timezone.utc).isoformat(), 269 "parent": from_id, 270 "branch_point": message_id, 271 "tags": base.get("tags", []), 272 "summary": None, 273 "messages": partial 274 } 275 self._save_file(new_id, branch) 276 return new_id 277 278 def summarize(self, interactor, session_id: str) -> str: 279 """ 280 Summarize the conversation in a session using the Interactor. 281 282 Args: 283 interactor: An instance of the Interactor class with messaging methods. 284 session_id (str): ID of the session to summarize. 285 286 Returns: 287 str: Generated summary string. 288 """ 289 clean_messages = self.load(session_id) 290 summary_prompt = "Summarize the following interaction in 3-4 sentences." 291 interactor.messages_flush() 292 interactor.messages_add(role="system", content=summary_prompt) 293 for msg in clean_messages: 294 interactor.messages_add(**msg) 295 summary = interactor.interact(user_input=None, stream=False, quiet=True) 296 self.update(session_id, "summary", summary) 297 return summary 298 299 # --------------------------- 300 # Search Capabilities 301 # --------------------------- 302 303 def search(self, query: str, session_id: Optional[str] = None) -> List[Dict]: 304 """ 305 Search messages for a text query across sessions or within a specific session. 306 307 Args: 308 query (str): Search term (case-insensitive). 309 session_id (str, optional): If specified, restrict search to that session. 310 311 Returns: 312 List[Dict]: List of session dicts containing only matching messages. 313 """ 314 result = [] 315 sessions = [session_id] if session_id else [f.stem for f in self.path.glob("*.json")] 316 317 for sid in sessions: 318 try: 319 data = self._read_file(sid) 320 matching = [ 321 msg for msg in data.get("messages", []) 322 if query.lower() in (msg.get("content") or "").lower() 323 ] 324 if matching: 325 result.append({ 326 "id": data["id"], 327 "name": data["name"], 328 "created": data["created"], 329 "parent": data.get("parent"), 330 "branch_point": data.get("branch_point"), 331 "tags": data.get("tags", []), 332 "summary": data.get("summary"), 333 "messages": matching 334 }) 335 except Exception: 336 continue 337 return result 338 339 def search_meta(self, query: str) -> List[Dict]: 340 """ 341 Search session metadata fields: name, tags, and summary. 342 343 Args: 344 query (str): Case-insensitive search term. 345 346 Returns: 347 List[Dict]: Matching sessions (without message contents). 348 """ 349 result = [] 350 for file in self.path.glob("*.json"): 351 try: 352 data = self._read_file(file.stem) 353 searchable = " ".join([ 354 data.get("name", ""), 355 data.get("summary", "") or "", 356 " ".join(data.get("tags", [])) 357 ]).lower() 358 if query.lower() in searchable: 359 result.append({ 360 "id": data["id"], 361 "name": data["name"], 362 "created": data["created"], 363 "parent": data.get("parent"), 364 "branch_point": data.get("branch_point"), 365 "tags": data.get("tags", []), 366 "summary": data.get("summary") 367 }) 368 except Exception: 369 continue 370 return result 371 372 # --------------------------- 373 # Internal I/O 374 # --------------------------- 375 376 def _read_file(self, session_id: str) -> Dict: 377 """ 378 Load a session file from disk. 379 380 Args: 381 session_id (str): ID of the session file to read. 382 383 Returns: 384 Dict: Parsed session data. 385 386 Raises: 387 FileNotFoundError: If the session file doesn't exist. 388 """ 389 file = self.path / f"{session_id}.json" 390 if not file.exists(): 391 raise FileNotFoundError(f"Session {session_id} not found.") 392 with open(file, "r") as f: 393 return json.load(f) 394 395 def _save_file(self, session_id: str, data: Dict): 396 """ 397 Write session data to disk. 398 399 Args: 400 session_id (str): ID of the session file to write. 401 data (Dict): Session content to save. 402 """ 403 file = self.path / f"{session_id}.json" 404 with open(file, "w") as f: 405 json.dump(data, f, indent=2)
20 def __init__(self, directory: str = None): 21 """ 22 Initialize the session manager and ensure the session directory exists. 23 24 Args: 25 directory (str): Filesystem path for session storage. Must not be None or empty. 26 27 Raises: 28 ValueError: If directory is None or not a string. 29 OSError: If the directory cannot be created or accessed. 30 """ 31 if not directory: 32 raise ValueError("Session directory must be a valid non-empty string path.") 33 34 try: 35 self.path = Path(os.path.expanduser(directory)) 36 self.path.mkdir(parents=True, exist_ok=True) 37 except Exception as e: 38 raise OSError(f"Failed to initialize session directory '{directory}': {e}")
Initialize the session manager and ensure the session directory exists.
Args: directory (str): Filesystem path for session storage. Must not be None or empty.
Raises: ValueError: If directory is None or not a string. OSError: If the directory cannot be created or accessed.
44 def list(self) -> List[Dict]: 45 """ 46 Return metadata for all sessions in the directory. 47 48 Returns: 49 List[Dict]: Sorted list of session metadata dictionaries. 50 """ 51 out = [] 52 for file in self.path.glob("*.json"): 53 try: 54 with open(file, "r") as f: 55 d = json.load(f) 56 out.append({ 57 "id": d.get("id"), 58 "name": d.get("name"), 59 "created": d.get("created"), 60 "tags": d.get("tags", []), 61 "summary": d.get("summary") 62 }) 63 except Exception: 64 continue 65 return sorted(out, key=lambda x: x["created"], reverse=True)
Return metadata for all sessions in the directory.
Returns: List[Dict]: Sorted list of session metadata dictionaries.
67 def create(self, name: str, tags: Optional[List[str]] = None) -> str: 68 """ 69 Create and persist a new session. 70 71 Args: 72 name (str): Name of the new session. 73 tags (List[str], optional): Optional list of tags. 74 75 Returns: 76 str: Unique session ID of the new session. 77 """ 78 sid = str(uuid.uuid4()) 79 session = { 80 "id": sid, 81 "name": name, 82 "created": datetime.now(timezone.utc).isoformat(), 83 "parent": None, 84 "branch_point": None, 85 "tags": tags or [], 86 "summary": None, 87 "messages": [] 88 } 89 self._save_file(sid, session) 90 return sid
Create and persist a new session.
Args: name (str): Name of the new session. tags (List[str], optional): Optional list of tags.
Returns: str: Unique session ID of the new session.
92 def load(self, session_id: str) -> List[Dict]: 93 """ 94 Return OpenAI-compatible message list from a session. 95 96 Filters out internal keys and leaves only standard API-compatible fields. 97 98 Args: 99 session_id (str): ID of the session to load. 100 101 Returns: 102 List[Dict]: List of clean message dictionaries. 103 """ 104 session = self._read_file(session_id) 105 return [ 106 {k: v for k, v in m.items() if k in { 107 "role", "content", "tool_calls", "name", "function_call", "tool_call_id" 108 }} for m in session.get("messages", []) 109 ]
Return OpenAI-compatible message list from a session.
Filters out internal keys and leaves only standard API-compatible fields.
Args: session_id (str): ID of the session to load.
Returns: List[Dict]: List of clean message dictionaries.
111 def load_full(self, session_id: str) -> Dict: 112 """ 113 Return the complete session file as-is. 114 115 Args: 116 session_id (str): ID of the session. 117 118 Returns: 119 Dict: Entire raw session data. 120 """ 121 return self._read_file(session_id)
Return the complete session file as-is.
Args: session_id (str): ID of the session.
Returns: Dict: Entire raw session data.
123 def delete(self, session_id: str): 124 """ 125 Delete a session file from disk. 126 127 Args: 128 session_id (str): ID of the session to delete. 129 """ 130 file = self.path / f"{session_id}.json" 131 if file.exists(): 132 file.unlink()
Delete a session file from disk.
Args: session_id (str): ID of the session to delete.
134 def update(self, session_id: str, key: str, value: Any): 135 """ 136 Update a top-level key in a session file. 137 138 Args: 139 session_id (str): Session ID. 140 key (str): Field to update. 141 value (Any): New value for the field. 142 """ 143 session = self._read_file(session_id) 144 session[key] = value 145 self._save_file(session_id, session)
Update a top-level key in a session file.
Args: session_id (str): Session ID. key (str): Field to update. value (Any): New value for the field.
151 def msg_insert(self, session_id: str, message: Dict) -> str: 152 """ 153 Insert a new message into a session. 154 155 Args: 156 session_id (str): Session ID. 157 message (Dict): Message dictionary to insert. 158 159 Returns: 160 str: ID of the inserted message. 161 """ 162 session = self._read_file(session_id) 163 entry = { 164 "id": str(uuid.uuid4()), 165 "timestamp": datetime.now(timezone.utc).isoformat(), 166 **message 167 } 168 session["messages"].append(entry) 169 self._save_file(session_id, session) 170 return entry["id"]
Insert a new message into a session.
Args: session_id (str): Session ID. message (Dict): Message dictionary to insert.
Returns: str: ID of the inserted message.
172 def msg_get(self, session_id: str, message_id: str) -> Optional[Dict]: 173 """ 174 Retrieve a specific message from a session. 175 176 Args: 177 session_id (str): Session ID. 178 message_id (str): ID of the message to retrieve. 179 180 Returns: 181 Optional[Dict]: The message if found, else None. 182 """ 183 session = self._read_file(session_id) 184 for msg in session.get("messages", []): 185 if msg.get("id") == message_id: 186 return msg 187 return None
Retrieve a specific message from a session.
Args: session_id (str): Session ID. message_id (str): ID of the message to retrieve.
Returns: Optional[Dict]: The message if found, else None.
189 def msg_index(self, session_id: str, message_id: str) -> Optional[int]: 190 """ 191 Get the index of a message within a session. 192 193 Args: 194 session_id (str): Session ID. 195 message_id (str): Message ID. 196 197 Returns: 198 Optional[int]: Index if found, else None. 199 """ 200 session = self._read_file(session_id) 201 for i, msg in enumerate(session.get("messages", [])): 202 if msg.get("id") == message_id: 203 return i 204 return None
Get the index of a message within a session.
Args: session_id (str): Session ID. message_id (str): Message ID.
Returns: Optional[int]: Index if found, else None.
206 def msg_update(self, session_id: str, message_id: str, new_content: str) -> bool: 207 """ 208 Update the content of a specific message. 209 210 Args: 211 session_id (str): Session ID. 212 message_id (str): Message ID. 213 new_content (str): New content for the message. 214 215 Returns: 216 bool: True if update succeeded, False otherwise. 217 """ 218 session = self._read_file(session_id) 219 for m in session["messages"]: 220 if m.get("id") == message_id: 221 m["content"] = new_content 222 self._save_file(session_id, session) 223 return True 224 return False
Update the content of a specific message.
Args: session_id (str): Session ID. message_id (str): Message ID. new_content (str): New content for the message.
Returns: bool: True if update succeeded, False otherwise.
226 def msg_delete(self, session_id: str, message_id: str) -> bool: 227 """ 228 Delete a message from a session. 229 230 Args: 231 session_id (str): Session ID. 232 message_id (str): Message ID. 233 234 Returns: 235 bool: True if deletion occurred, False otherwise. 236 """ 237 session = self._read_file(session_id) 238 before = len(session["messages"]) 239 session["messages"] = [m for m in session["messages"] if m.get("id") != message_id] 240 self._save_file(session_id, session) 241 return len(session["messages"]) < before
Delete a message from a session.
Args: session_id (str): Session ID. message_id (str): Message ID.
Returns: bool: True if deletion occurred, False otherwise.
247 def branch(self, from_id: str, message_id: str, new_name: str) -> str: 248 """ 249 Create a new session by branching from a specific message. 250 251 Args: 252 from_id (str): ID of the base session. 253 message_id (str): Message ID to branch from. 254 new_name (str): Name of the new session. 255 256 Returns: 257 str: ID of the new session. 258 """ 259 base = self._read_file(from_id) 260 index = next((i for i, m in enumerate(base["messages"]) if m.get("id") == message_id), None) 261 if index is None: 262 raise ValueError(f"Message ID {message_id} not found in session {from_id}.") 263 partial = base["messages"][:index + 1] 264 new_id = str(uuid.uuid4()) 265 branch = { 266 "id": new_id, 267 "name": new_name, 268 "created": datetime.now(timezone.utc).isoformat(), 269 "parent": from_id, 270 "branch_point": message_id, 271 "tags": base.get("tags", []), 272 "summary": None, 273 "messages": partial 274 } 275 self._save_file(new_id, branch) 276 return new_id
Create a new session by branching from a specific message.
Args: from_id (str): ID of the base session. message_id (str): Message ID to branch from. new_name (str): Name of the new session.
Returns: str: ID of the new session.
278 def summarize(self, interactor, session_id: str) -> str: 279 """ 280 Summarize the conversation in a session using the Interactor. 281 282 Args: 283 interactor: An instance of the Interactor class with messaging methods. 284 session_id (str): ID of the session to summarize. 285 286 Returns: 287 str: Generated summary string. 288 """ 289 clean_messages = self.load(session_id) 290 summary_prompt = "Summarize the following interaction in 3-4 sentences." 291 interactor.messages_flush() 292 interactor.messages_add(role="system", content=summary_prompt) 293 for msg in clean_messages: 294 interactor.messages_add(**msg) 295 summary = interactor.interact(user_input=None, stream=False, quiet=True) 296 self.update(session_id, "summary", summary) 297 return summary
Summarize the conversation in a session using the Interactor.
Args: interactor: An instance of the Interactor class with messaging methods. session_id (str): ID of the session to summarize.
Returns: str: Generated summary string.
303 def search(self, query: str, session_id: Optional[str] = None) -> List[Dict]: 304 """ 305 Search messages for a text query across sessions or within a specific session. 306 307 Args: 308 query (str): Search term (case-insensitive). 309 session_id (str, optional): If specified, restrict search to that session. 310 311 Returns: 312 List[Dict]: List of session dicts containing only matching messages. 313 """ 314 result = [] 315 sessions = [session_id] if session_id else [f.stem for f in self.path.glob("*.json")] 316 317 for sid in sessions: 318 try: 319 data = self._read_file(sid) 320 matching = [ 321 msg for msg in data.get("messages", []) 322 if query.lower() in (msg.get("content") or "").lower() 323 ] 324 if matching: 325 result.append({ 326 "id": data["id"], 327 "name": data["name"], 328 "created": data["created"], 329 "parent": data.get("parent"), 330 "branch_point": data.get("branch_point"), 331 "tags": data.get("tags", []), 332 "summary": data.get("summary"), 333 "messages": matching 334 }) 335 except Exception: 336 continue 337 return result
Search messages for a text query across sessions or within a specific session.
Args: query (str): Search term (case-insensitive). session_id (str, optional): If specified, restrict search to that session.
Returns: List[Dict]: List of session dicts containing only matching messages.
339 def search_meta(self, query: str) -> List[Dict]: 340 """ 341 Search session metadata fields: name, tags, and summary. 342 343 Args: 344 query (str): Case-insensitive search term. 345 346 Returns: 347 List[Dict]: Matching sessions (without message contents). 348 """ 349 result = [] 350 for file in self.path.glob("*.json"): 351 try: 352 data = self._read_file(file.stem) 353 searchable = " ".join([ 354 data.get("name", ""), 355 data.get("summary", "") or "", 356 " ".join(data.get("tags", [])) 357 ]).lower() 358 if query.lower() in searchable: 359 result.append({ 360 "id": data["id"], 361 "name": data["name"], 362 "created": data["created"], 363 "parent": data.get("parent"), 364 "branch_point": data.get("branch_point"), 365 "tags": data.get("tags", []), 366 "summary": data.get("summary") 367 }) 368 except Exception: 369 continue 370 return result
Search session metadata fields: name, tags, and summary.
Args: query (str): Case-insensitive search term.
Returns: List[Dict]: Matching sessions (without message contents).