kiln_ai.adapters.ml_model_list
import os
from dataclasses import dataclass
from enum import Enum
from os import getenv
from typing import Any, Dict, List, NoReturn

import httpx
import requests
from langchain_aws import ChatBedrockConverse
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

from ..utils.config import Config

"""
Provides model configuration and management for various LLM providers and models.
This module handles the integration with different AI model providers and their respective models,
including configuration, validation, and instantiation of language models.
"""


class ModelProviderName(str, Enum):
    """
    Enumeration of supported AI model providers.
    """

    openai = "openai"
    groq = "groq"
    amazon_bedrock = "amazon_bedrock"
    ollama = "ollama"
    openrouter = "openrouter"


class ModelFamily(str, Enum):
    """
    Enumeration of supported model families/architectures.
    """

    gpt = "gpt"
    llama = "llama"
    phi = "phi"
    mistral = "mistral"
    gemma = "gemma"
    gemini = "gemini"
    claude = "claude"


# Where models have instruct and raw versions, instruct is default and raw is specified
class ModelName(str, Enum):
    """
    Enumeration of specific model versions supported by the system.
    Where models have instruct and raw versions, instruct is default and raw is specified.
    """

    llama_3_1_8b = "llama_3_1_8b"
    llama_3_1_70b = "llama_3_1_70b"
    llama_3_1_405b = "llama_3_1_405b"
    llama_3_2_3b = "llama_3_2_3b"
    llama_3_2_11b = "llama_3_2_11b"
    llama_3_2_90b = "llama_3_2_90b"
    gpt_4o_mini = "gpt_4o_mini"
    gpt_4o = "gpt_4o"
    phi_3_5 = "phi_3_5"
    mistral_large = "mistral_large"
    mistral_nemo = "mistral_nemo"
    gemma_2_2b = "gemma_2_2b"
    gemma_2_9b = "gemma_2_9b"
    gemma_2_27b = "gemma_2_27b"
    claude_3_5_haiku = "claude_3_5_haiku"
    claude_3_5_sonnet = "claude_3_5_sonnet"
    gemini_1_5_flash = "gemini_1_5_flash"
    gemini_1_5_flash_8b = "gemini_1_5_flash_8b"
    gemini_1_5_pro = "gemini_1_5_pro"
    nemotron_70b = "nemotron_70b"


class KilnModelProvider(BaseModel):
    """
    Configuration for a specific model provider.

    Attributes:
        name: The provider's identifier
        supports_structured_output: Whether the provider supports structured output formats
        provider_options: Additional provider-specific configuration options
    """

    name: ModelProviderName
    supports_structured_output: bool = True
    supports_data_gen: bool = True
    provider_finetune_id: str | None = None
    provider_options: Dict = {}


class KilnModel(BaseModel):
    """
    Configuration for a specific AI model.

    Attributes:
        family: The model's architecture family
        name: The model's identifier
        friendly_name: Human-readable name for the model
        providers: List of providers that offer this model
        supports_structured_output: Whether the model supports structured output formats
    """

    family: str
    name: str
    friendly_name: str
    providers: List[KilnModelProvider]
    supports_structured_output: bool = True


built_in_models: List[KilnModel] = [
    # GPT 4o Mini
    KilnModel(
        family=ModelFamily.gpt,
        name=ModelName.gpt_4o_mini,
        friendly_name="GPT 4o Mini",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openai,
                provider_options={"model": "gpt-4o-mini"},
                provider_finetune_id="gpt-4o-mini-2024-07-18",
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "openai/gpt-4o-mini"},
            ),
        ],
    ),
    # GPT 4o
    KilnModel(
        family=ModelFamily.gpt,
        name=ModelName.gpt_4o,
        friendly_name="GPT 4o",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openai,
                provider_options={"model": "gpt-4o"},
                provider_finetune_id="gpt-4o-2024-08-06",
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "openai/gpt-4o-2024-08-06"},
            ),
        ],
    ),
    # Claude 3.5 Haiku
    KilnModel(
        family=ModelFamily.claude,
        name=ModelName.claude_3_5_haiku,
        friendly_name="Claude 3.5 Haiku",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "anthropic/claude-3-5-haiku"},
            ),
        ],
    ),
    # Claude 3.5 Sonnet
    KilnModel(
        family=ModelFamily.claude,
        name=ModelName.claude_3_5_sonnet,
        friendly_name="Claude 3.5 Sonnet",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "anthropic/claude-3.5-sonnet"},
            ),
        ],
    ),
    # Gemini 1.5 Pro
    KilnModel(
        family=ModelFamily.gemini,
        name=ModelName.gemini_1_5_pro,
        friendly_name="Gemini 1.5 Pro",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,  # it should, but doesn't work on openrouter
                supports_data_gen=False,  # doesn't work on openrouter
                provider_options={"model": "google/gemini-pro-1.5"},
            ),
        ],
    ),
    # Gemini 1.5 Flash
    KilnModel(
        family=ModelFamily.gemini,
        name=ModelName.gemini_1_5_flash,
        friendly_name="Gemini 1.5 Flash",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_data_gen=False,
                provider_options={"model": "google/gemini-flash-1.5"},
            ),
        ],
    ),
    # Gemini 1.5 Flash 8B
    KilnModel(
        family=ModelFamily.gemini,
        name=ModelName.gemini_1_5_flash_8b,
        friendly_name="Gemini 1.5 Flash 8B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "google/gemini-flash-1.5-8b"},
            ),
        ],
    ),
    # Nemotron 70B
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.nemotron_70b,
        friendly_name="Nemotron 70B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "nvidia/llama-3.1-nemotron-70b-instruct"},
            ),
        ],
    ),
    # Llama 3.1-8b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_8b,
        friendly_name="Llama 3.1 8B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.groq,
                provider_options={"model": "llama-3.1-8b-instant"},
            ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={
                    "model": "meta.llama3-1-8b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_data_gen=False,
                provider_options={
                    "model": "llama3.1:8b",
                    "model_aliases": ["llama3.1"],  # 8b is default
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
            ),
        ],
    ),
    # Llama 3.1 70b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_70b,
        friendly_name="Llama 3.1 70B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.groq,
                provider_options={"model": "llama-3.1-70b-versatile"},
            ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                # not sure how AWS manages to break this, but it's not working
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={
                    "model": "meta.llama3-1-70b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "llama3.1:70b"},
            ),
        ],
    ),
    # Llama 3.1 405b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_405b,
        friendly_name="Llama 3.1 405B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                supports_data_gen=False,
                provider_options={
                    "model": "meta.llama3-1-405b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "llama3.1:405b"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
            ),
        ],
    ),
    # Mistral Nemo
    KilnModel(
        family=ModelFamily.mistral,
        name=ModelName.mistral_nemo,
        friendly_name="Mistral Nemo",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "mistralai/mistral-nemo"},
            ),
        ],
    ),
    # Mistral Large
    KilnModel(
        family=ModelFamily.mistral,
        name=ModelName.mistral_large,
        friendly_name="Mistral Large",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                provider_options={
                    "model": "mistral.mistral-large-2407-v1:0",
                    "region_name": "us-west-2",  # only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "mistralai/mistral-large"},
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "mistral-large"},
            ),
        ],
    ),
    # Llama 3.2 3B
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_2_3b,
        friendly_name="Llama 3.2 3B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "meta-llama/llama-3.2-3b-instruct"},
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "llama3.2"},
            ),
        ],
    ),
    # Llama 3.2 11B
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_2_11b,
        friendly_name="Llama 3.2 11B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "llama3.2-vision"},
            ),
        ],
    ),
    # Llama 3.2 90B
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_2_90b,
        friendly_name="Llama 3.2 90B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "llama3.2-vision:90b"},
            ),
        ],
    ),
    # Phi 3.5
    KilnModel(
        family=ModelFamily.phi,
        name=ModelName.phi_3_5,
        friendly_name="Phi 3.5",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "phi3.5"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
            ),
        ],
    ),
    # Gemma 2 2.6b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_2b,
        friendly_name="Gemma 2 2B",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={
                    "model": "gemma2:2b",
                },
            ),
        ],
    ),
    # Gemma 2 9b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_9b,
        friendly_name="Gemma 2 9B",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_data_gen=False,
                provider_options={
                    "model": "gemma2:9b",
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_data_gen=False,
                provider_options={"model": "google/gemma-2-9b-it"},
            ),
        ],
    ),
    # Gemma 2 27b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_27b,
        friendly_name="Gemma 2 27B",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_data_gen=False,
                provider_options={
                    "model": "gemma2:27b",
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_data_gen=False,
                provider_options={"model": "google/gemma-2-27b-it"},
            ),
        ],
    ),
]


def get_model_and_provider(
    model_name: str, provider_name: str
) -> tuple[KilnModel | None, KilnModelProvider | None]:
    model = next(filter(lambda m: m.name == model_name, built_in_models), None)
    if model is None:
        return None, None
    provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
    # all or nothing
    if provider is None or model is None:
        return None, None
    return model, provider


def provider_name_from_id(id: str) -> str:
    """
    Converts a provider ID to its human-readable name.

    Args:
        id: The provider identifier string

    Returns:
        The human-readable name of the provider

    Raises:
        ValueError: If the provider ID is invalid or unhandled
    """
    if id in ModelProviderName.__members__:
        enum_id = ModelProviderName(id)
        match enum_id:
            case ModelProviderName.amazon_bedrock:
                return "Amazon Bedrock"
            case ModelProviderName.openrouter:
                return "OpenRouter"
            case ModelProviderName.groq:
                return "Groq"
            case ModelProviderName.ollama:
                return "Ollama"
            case ModelProviderName.openai:
                return "OpenAI"
            case _:
                # triggers pyright warning if I miss a case
                raise_exhaustive_error(enum_id)

    return "Unknown provider: " + id


def raise_exhaustive_error(value: NoReturn) -> NoReturn:
    raise ValueError(f"Unhandled enum value: {value}")


@dataclass
class ModelProviderWarning:
    required_config_keys: List[str]
    message: str


provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
    ModelProviderName.amazon_bedrock: ModelProviderWarning(
        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
    ),
    ModelProviderName.openrouter: ModelProviderWarning(
        required_config_keys=["open_router_api_key"],
        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
    ),
    ModelProviderName.groq: ModelProviderWarning(
        required_config_keys=["groq_api_key"],
        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
    ),
    ModelProviderName.openai: ModelProviderWarning(
        required_config_keys=["open_ai_api_key"],
        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
    ),
}


async def provider_enabled(provider_name: ModelProviderName) -> bool:
    if provider_name == ModelProviderName.ollama:
        try:
            tags = await get_ollama_connection()
            return tags is not None and len(tags.models) > 0
        except Exception:
            return False

    provider_warning = provider_warnings.get(provider_name)
    if provider_warning is None:
        return False
    for required_key in provider_warning.required_config_keys:
        if get_config_value(required_key) is None:
            return False
    return True


def get_config_value(key: str):
    try:
        return Config.shared().__getattr__(key)
    except AttributeError:
        return None


def check_provider_warnings(provider_name: ModelProviderName):
    """
    Validates that required configuration is present for a given provider.

    Args:
        provider_name: The provider to check

    Raises:
        ValueError: If required configuration keys are missing
    """
    warning_check = provider_warnings.get(provider_name)
    if warning_check is None:
        return
    for key in warning_check.required_config_keys:
        if get_config_value(key) is None:
            raise ValueError(warning_check.message)


async def langchain_model_from(
    name: str, provider_name: str | None = None
) -> BaseChatModel:
    """
    Creates a LangChain chat model instance for the specified model and provider.

    Args:
        name: The name of the model to instantiate
        provider_name: Optional specific provider to use (defaults to first available)

    Returns:
        A configured LangChain chat model instance

    Raises:
        ValueError: If the model/provider combination is invalid or misconfigured
    """
    if name not in ModelName.__members__:
        raise ValueError(f"Invalid name: {name}")

    # Select the model from built_in_models using the name
    model = next(filter(lambda m: m.name == name, built_in_models))
    if model is None:
        raise ValueError(f"Model {name} not found")

    # If a provider is provided, select the provider from the model's provider_config
    provider: KilnModelProvider | None = None
    if model.providers is None or len(model.providers) == 0:
        raise ValueError(f"Model {name} has no providers")
    elif provider_name is None:
        # TODO: priority order
        provider = model.providers[0]
    else:
        provider = next(
            filter(lambda p: p.name == provider_name, model.providers), None
        )
        if provider is None:
            raise ValueError(f"Provider {provider_name} not found for model {name}")

    check_provider_warnings(provider.name)

    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't allow passing these, so ugly hack to set env vars
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.ollama:
        # Ollama model naming is pretty flexible. We try a few versions of the model name
        potential_model_names = []
        if "model" in provider.provider_options:
            potential_model_names.append(provider.provider_options["model"])
        if "model_aliases" in provider.provider_options:
            potential_model_names.extend(provider.provider_options["model_aliases"])

        # Get the list of models Ollama supports
        ollama_connection = await get_ollama_connection()
        if ollama_connection is None:
            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")

        for model_name in potential_model_names:
            if ollama_model_supported(ollama_connection, model_name):
                return ChatOllama(model=model_name, base_url=ollama_base_url())

        raise ValueError(f"Model {name} not installed on Ollama")
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://getkiln.ai/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")


def ollama_base_url() -> str:
    """
    Gets the base URL for Ollama API connections.

    Returns:
        The base URL to use for Ollama API calls, using environment variable if set
        or falling back to localhost default
    """
    env_base_url = os.getenv("OLLAMA_BASE_URL")
    if env_base_url is not None:
        return env_base_url
    return "http://localhost:11434"


async def ollama_online() -> bool:
    """
    Checks if the Ollama service is available and responding.

    Returns:
        True if Ollama is available and responding, False otherwise
    """
    try:
        httpx.get(ollama_base_url() + "/api/tags")
    except httpx.RequestError:
        return False
    return True


class OllamaConnection(BaseModel):
    message: str
    models: List[str]


# Parse the Ollama /api/tags response
def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
    # Build a list of models we support for Ollama from the built-in model list
    supported_ollama_models = [
        provider.provider_options["model"]
        for model in built_in_models
        for provider in model.providers
        if provider.name == ModelProviderName.ollama
    ]
    # Append model_aliases to supported_ollama_models
    supported_ollama_models.extend(
        [
            alias
            for model in built_in_models
            for provider in model.providers
            for alias in provider.provider_options.get("model_aliases", [])
        ]
    )

    if "models" in tags:
        models = tags["models"]
        if isinstance(models, list):
            model_names = [model["model"] for model in models]
            available_supported_models = [
                model
                for model in model_names
                if model in supported_ollama_models
                or model in [f"{m}:latest" for m in supported_ollama_models]
            ]
            if available_supported_models:
                return OllamaConnection(
                    message="Ollama connected",
                    models=available_supported_models,
                )

    return OllamaConnection(
        message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
        models=[],
    )


async def get_ollama_connection() -> OllamaConnection | None:
    """
    Gets the connection status for Ollama.
    """
    try:
        tags = requests.get(ollama_base_url() + "/api/tags", timeout=5).json()
    except Exception:
        return None

    return parse_ollama_tags(tags)


def ollama_model_supported(conn: OllamaConnection, model_name: str) -> bool:
    return model_name in conn.models or f"{model_name}:latest" in conn.models
class ModelProviderName(str, Enum):
Enumeration of supported AI model providers.
class ModelFamily(str, Enum):
Enumeration of supported model families/architectures.
class ModelName(str, Enum):
Enumeration of specific model versions supported by the system. Where models have instruct and raw versions, instruct is default and raw is specified.
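Both ModelName and ModelProviderName subclass str, so members compare equal to their string values and untrusted input can be validated with a membership check, mirroring the check langchain_model_from performs internally. A minimal sketch:

from kiln_ai.adapters.ml_model_list import ModelName, ModelProviderName

# Enum members double as plain strings
assert ModelName.gpt_4o_mini == "gpt_4o_mini"
assert ModelProviderName.ollama.value == "ollama"

# Validate untrusted input the same way the module does internally
user_input = "llama_3_1_8b"
if user_input in ModelName.__members__:
    name = ModelName(user_input)
    print(f"Valid model: {name.value}")
else:
    print(f"Unknown model: {user_input}")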
class KilnModelProvider(BaseModel):
Configuration for a specific model provider.
Attributes:
    name: The provider's identifier
    supports_structured_output: Whether the provider supports structured output formats
    supports_data_gen: Whether the provider supports data generation workflows for this model
    provider_finetune_id: Provider-specific identifier used for fine-tuning this model, if available
    provider_options: Additional provider-specific configuration options
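A sketch of constructing a provider entry by hand, the same shape used throughout built_in_models; the specific field values here are illustrative, not an officially supported configuration:

from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName

provider = KilnModelProvider(
    name=ModelProviderName.ollama,
    supports_data_gen=False,  # capability flags default to True
    provider_options={"model": "llama3.1:8b", "model_aliases": ["llama3.1"]},
)
print(provider.supports_structured_output)  # True (default)
print(provider.provider_options["model"])   # llama3.1:8b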
class KilnModel(BaseModel):
Configuration for a specific AI model.
Attributes:
    family: The model's architecture family
    name: The model's identifier
    friendly_name: Human-readable name for the model
    providers: List of providers that offer this model
    supports_structured_output: Whether the model supports structured output formats
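For illustration, the sketch below walks built_in_models and prints each model's friendly name with the providers that serve it; it uses only fields defined on KilnModel and KilnModelProvider above.

from kiln_ai.adapters.ml_model_list import built_in_models

for model in built_in_models:
    provider_ids = [provider.name.value for provider in model.providers]
    print(f"{model.friendly_name} ({model.name.value}): {', '.join(provider_ids)}")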
def get_model_and_provider(
    model_name: str, provider_name: str
) -> tuple[KilnModel | None, KilnModelProvider | None]:
    model = next(filter(lambda m: m.name == model_name, built_in_models), None)
    if model is None:
        return None, None
    provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
    # all or nothing
    if provider is None or model is None:
        return None, None
    return model, provider
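Usage sketch: the function returns (None, None) when either the model or the provider is unknown, so callers should check both values together.

from kiln_ai.adapters.ml_model_list import get_model_and_provider

model, provider = get_model_and_provider("llama_3_1_8b", "groq")
if model is None or provider is None:
    print("Unknown model/provider combination")
else:
    print(model.friendly_name, "->", provider.provider_options["model"])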
def provider_name_from_id(id: str) -> str:
Converts a provider ID to its human-readable name.
Args: id: The provider identifier string
Returns: The human-readable name of the provider
Raises: ValueError: If a known provider value is left unhandled by the match statement (exhaustiveness guard); unrecognized IDs do not raise and instead return the string "Unknown provider: <id>"
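A small usage sketch; note that unrecognized IDs produce the fallback string rather than an exception.

from kiln_ai.adapters.ml_model_list import provider_name_from_id

print(provider_name_from_id("amazon_bedrock"))  # Amazon Bedrock
print(provider_name_from_id("not_a_provider"))  # Unknown provider: not_a_provider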
async def provider_enabled(provider_name: ModelProviderName) -> bool:
    if provider_name == ModelProviderName.ollama:
        try:
            tags = await get_ollama_connection()
            return tags is not None and len(tags.models) > 0
        except Exception:
            return False

    provider_warning = provider_warnings.get(provider_name)
    if provider_warning is None:
        return False
    for required_key in provider_warning.required_config_keys:
        if get_config_value(required_key) is None:
            return False
    return True
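provider_enabled is a coroutine (it may probe a local Ollama server), so it has to be awaited. A minimal sketch that reports the status of every known provider:

import asyncio

from kiln_ai.adapters.ml_model_list import ModelProviderName, provider_enabled

async def main() -> None:
    for provider in ModelProviderName:
        enabled = await provider_enabled(provider)
        print(f"{provider.value}: {'enabled' if enabled else 'not configured'}")

asyncio.run(main())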
def check_provider_warnings(provider_name: ModelProviderName):
Validates that required configuration is present for a given provider.
Args: provider_name: The provider to check
Raises: ValueError: If required configuration keys are missing
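Unlike provider_enabled, this raises rather than returning a boolean, so a typical call wraps it in try/except to surface the configuration message to the user. A sketch:

from kiln_ai.adapters.ml_model_list import ModelProviderName, check_provider_warnings

try:
    check_provider_warnings(ModelProviderName.groq)
except ValueError as err:
    # err carries the human-readable setup instructions from provider_warnings
    print(f"Provider not ready: {err}")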
async def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
Creates a LangChain chat model instance for the specified model and provider.
Args:
    name: The name of the model to instantiate
    provider_name: Optional specific provider to use (defaults to first available)
Returns: A configured LangChain chat model instance
Raises: ValueError: If the model/provider combination is invalid or misconfigured
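End-to-end sketch, assuming the OpenAI API key has already been saved in Kiln's Config; the prompt string is purely illustrative:

import asyncio

from kiln_ai.adapters.ml_model_list import langchain_model_from

async def main() -> None:
    # Falls back to the model's first listed provider when provider_name is omitted
    model = await langchain_model_from("gpt_4o_mini", provider_name="openai")
    reply = await model.ainvoke("Say hello in one short sentence.")
    print(reply.content)

asyncio.run(main())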
def ollama_base_url() -> str:
Gets the base URL for Ollama API connections.
Returns: The base URL to use for Ollama API calls, using environment variable if set or falling back to localhost default
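The URL can be redirected to a remote or containerized Ollama instance by setting OLLAMA_BASE_URL before the lookup; a small sketch (the address is illustrative):

import os

from kiln_ai.adapters.ml_model_list import ollama_base_url

print(ollama_base_url())  # http://localhost:11434 by default

os.environ["OLLAMA_BASE_URL"] = "http://192.168.1.50:11434"  # illustrative address
print(ollama_base_url())  # now returns the override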
async def ollama_online() -> bool:
Checks if the Ollama service is available and responding.
Returns: True if Ollama is available and responding, False otherwise
class OllamaConnection(BaseModel):
    message: str
    models: List[str]

Connection status for a local Ollama instance: a human-readable status message and the list of installed models that Kiln supports.
async def get_ollama_connection() -> OllamaConnection | None:
Gets the connection status for Ollama.
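A usage sketch combining the connection check with the parsed result; None means the server could not be reached at all, while an empty models list means it is running without any supported models installed.

import asyncio

from kiln_ai.adapters.ml_model_list import get_ollama_connection

async def main() -> None:
    connection = await get_ollama_connection()
    if connection is None:
        print("Ollama is not reachable")
    else:
        print(connection.message)
        for model in connection.models:
            print(" -", model)

asyncio.run(main())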