kiln_ai.adapters.ml_model_list
import os
from dataclasses import dataclass
from enum import Enum
from os import getenv
from typing import Dict, List, NoReturn

import httpx
from langchain_aws import ChatBedrockConverse
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

from ..utils.config import Config


class ModelProviderName(str, Enum):
    openai = "openai"
    groq = "groq"
    amazon_bedrock = "amazon_bedrock"
    ollama = "ollama"
    openrouter = "openrouter"


class ModelFamily(str, Enum):
    gpt = "gpt"
    llama = "llama"
    phi = "phi"
    mistral = "mistral"
    gemma = "gemma"


# Where models have instruct and raw versions, instruct is default and raw is specified
class ModelName(str, Enum):
    llama_3_1_8b = "llama_3_1_8b"
    llama_3_1_70b = "llama_3_1_70b"
    llama_3_1_405b = "llama_3_1_405b"
    gpt_4o_mini = "gpt_4o_mini"
    gpt_4o = "gpt_4o"
    phi_3_5 = "phi_3_5"
    mistral_large = "mistral_large"
    mistral_nemo = "mistral_nemo"
    gemma_2_2b = "gemma_2_2b"
    gemma_2_9b = "gemma_2_9b"
    gemma_2_27b = "gemma_2_27b"


class KilnModelProvider(BaseModel):
    name: ModelProviderName
    # Allow overriding the model level setting
    supports_structured_output: bool = True
    provider_options: Dict = {}


class KilnModel(BaseModel):
    family: str
    name: str
    friendly_name: str
    providers: List[KilnModelProvider]
    supports_structured_output: bool = True


built_in_models: List[KilnModel] = [
    # GPT 4o Mini
    KilnModel(
        family=ModelFamily.gpt,
        name=ModelName.gpt_4o_mini,
        friendly_name="GPT 4o Mini",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openai,
                provider_options={"model": "gpt-4o-mini"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "openai/gpt-4o-mini"},
            ),
        ],
    ),
    # GPT 4o
    KilnModel(
        family=ModelFamily.gpt,
        name=ModelName.gpt_4o,
        friendly_name="GPT 4o",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openai,
                provider_options={"model": "gpt-4o"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "openai/gpt-4o-2024-08-06"},
            ),
        ],
    ),
    # Llama 3.1-8b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_8b,
        friendly_name="Llama 3.1 8B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.groq,
                provider_options={"model": "llama-3.1-8b-instant"},
            ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                provider_options={
                    "model": "meta.llama3-1-8b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "llama3.1"},  # 8b is default
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
            ),
        ],
    ),
    # Llama 3.1 70b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_70b,
        friendly_name="Llama 3.1 70B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.groq,
                provider_options={"model": "llama-3.1-70b-versatile"},
            ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                # TODO: this should work, but a bug in the Bedrock response schema blocks it for now
                supports_structured_output=False,
                provider_options={
                    "model": "meta.llama3-1-70b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
            ),
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     provider=ModelProviders.ollama,
            #     provider_options={"model": "llama3.1:70b"},
            # ),
        ],
    ),
    # Llama 3.1 405b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_405b,
        friendly_name="Llama 3.1 405B",
        providers=[
            # TODO: bring back when groq does: https://console.groq.com/docs/models
            # KilnModelProvider(
            #     name=ModelProviderName.groq,
            #     provider_options={"model": "llama-3.1-405b-instruct-v1:0"},
            # ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                provider_options={
                    "model": "meta.llama3-1-405b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={"model": "llama3.1:405b"},
            # ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
            ),
        ],
    ),
    # Mistral Nemo
    KilnModel(
        family=ModelFamily.mistral,
        name=ModelName.mistral_nemo,
        friendly_name="Mistral Nemo",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "mistralai/mistral-nemo"},
            ),
        ],
    ),
    # Mistral Large
    KilnModel(
        family=ModelFamily.mistral,
        name=ModelName.mistral_large,
        friendly_name="Mistral Large",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                provider_options={
                    "model": "mistral.mistral-large-2407-v1:0",
                    "region_name": "us-west-2",  # only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "mistralai/mistral-large"},
            ),
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     provider=ModelProviders.ollama,
            #     provider_options={"model": "mistral-large"},
            # ),
        ],
    ),
    # Phi 3.5
    KilnModel(
        family=ModelFamily.phi,
        name=ModelName.phi_3_5,
        friendly_name="Phi 3.5",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "phi3.5"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
            ),
        ],
    ),
    # Gemma 2 2.6b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_2b,
        friendly_name="Gemma 2 2B",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={
                    "model": "gemma2:2b",
                },
            ),
        ],
    ),
    # Gemma 2 9b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_9b,
        friendly_name="Gemma 2 9B",
        supports_structured_output=False,
        providers=[
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={
            #         "model": "gemma2:9b",
            #     },
            # ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "google/gemma-2-9b-it"},
            ),
        ],
    ),
    # Gemma 2 27b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_27b,
        friendly_name="Gemma 2 27B",
        supports_structured_output=False,
        providers=[
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={
            #         "model": "gemma2:27b",
            #     },
            # ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "google/gemma-2-27b-it"},
            ),
        ],
    ),
]


def provider_name_from_id(id: str) -> str:
    if id in ModelProviderName.__members__:
        enum_id = ModelProviderName(id)
        match enum_id:
            case ModelProviderName.amazon_bedrock:
                return "Amazon Bedrock"
            case ModelProviderName.openrouter:
                return "OpenRouter"
            case ModelProviderName.groq:
                return "Groq"
            case ModelProviderName.ollama:
                return "Ollama"
            case ModelProviderName.openai:
                return "OpenAI"
            case _:
                # triggers pyright warning if I miss a case
                raise_exhaustive_error(enum_id)

    return "Unknown provider: " + id


def raise_exhaustive_error(value: NoReturn) -> NoReturn:
    raise ValueError(f"Unhandled enum value: {value}")


@dataclass
class ModelProviderWarning:
    required_config_keys: List[str]
    message: str


provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
    ModelProviderName.amazon_bedrock: ModelProviderWarning(
        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
    ),
    ModelProviderName.openrouter: ModelProviderWarning(
        required_config_keys=["open_router_api_key"],
        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
    ),
    ModelProviderName.groq: ModelProviderWarning(
        required_config_keys=["groq_api_key"],
        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
    ),
    ModelProviderName.openai: ModelProviderWarning(
        required_config_keys=["open_ai_api_key"],
        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
    ),
}


def get_config_value(key: str):
    try:
        return Config.shared().__getattr__(key)
    except AttributeError:
        return None


def check_provider_warnings(provider_name: ModelProviderName):
    warning_check = provider_warnings.get(provider_name)
    if warning_check is None:
        return
    for key in warning_check.required_config_keys:
        if get_config_value(key) is None:
            raise ValueError(warning_check.message)


def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
    if name not in ModelName.__members__:
        raise ValueError(f"Invalid name: {name}")

    # Select the model from built_in_models using the name
    # (the None default avoids StopIteration if the name is missing)
    model = next(filter(lambda m: m.name == name, built_in_models), None)
    if model is None:
        raise ValueError(f"Model {name} not found")

    # If a provider is specified, select it from the model's providers
    provider: KilnModelProvider | None = None
    if model.providers is None or len(model.providers) == 0:
        raise ValueError(f"Model {name} has no providers")
    elif provider_name is None:
        # TODO: priority order
        provider = model.providers[0]
    else:
        provider = next(
            filter(lambda p: p.name == provider_name, model.providers), None
        )
        if provider is None:
            raise ValueError(f"Provider {provider_name} not found for model {name}")

    check_provider_warnings(provider.name)

    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't allow passing these, so ugly hack to set env vars
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.ollama:
        return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://kiln-ai.com/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")


def ollama_base_url() -> str:
    env_base_url = os.getenv("OLLAMA_BASE_URL")
    if env_base_url is not None:
        return env_base_url
    return "http://localhost:11434"


async def ollama_online() -> bool:
    try:
        # Use an async client so the coroutine doesn't block the event loop
        async with httpx.AsyncClient() as client:
            await client.get(ollama_base_url() + "/api/tags")
    except httpx.RequestError:
        return False
    return True
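A minimal usage sketch (an editor's illustration, not part of the module) showing how the built_in_models registry above can be queried. It assumes only the definitions in this module:

# Sketch: querying the built_in_models registry
from kiln_ai.adapters.ml_model_list import built_in_models

# Models that support structured output at the model level
structured = [m.friendly_name for m in built_in_models if m.supports_structured_output]
print(structured)

# Providers that can serve Llama 3.1 8B, with their provider-specific model ids
llama_8b = next(m for m in built_in_models if m.name == "llama_3_1_8b")
for p in llama_8b.providers:
    print(p.name.value, p.provider_options.get("model"))

The string comparison m.name == "llama_3_1_8b" works because ModelName subclasses str, as shown in the enum sections below.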
class ModelProviderName(str, Enum):
    openai = "openai"
    groq = "groq"
    amazon_bedrock = "amazon_bedrock"
    ollama = "ollama"
    openrouter = "openrouter"
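Because the enum subclasses str, its members compare equal to their raw string values, which is what the __members__ checks and string comparisons in this module rely on. An illustrative sketch:

# Illustrative: string-enum behavior of ModelProviderName
assert ModelProviderName.groq == "groq"                                  # members are strings
assert "ollama" in ModelProviderName.__members__                         # lookup by member name
assert ModelProviderName("openrouter") is ModelProviderName.openrouter   # round-trip from a raw string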
class ModelFamily(str, Enum):
    gpt = "gpt"
    llama = "llama"
    phi = "phi"
    mistral = "mistral"
    gemma = "gemma"
class ModelName(str, Enum):
    llama_3_1_8b = "llama_3_1_8b"
    llama_3_1_70b = "llama_3_1_70b"
    llama_3_1_405b = "llama_3_1_405b"
    gpt_4o_mini = "gpt_4o_mini"
    gpt_4o = "gpt_4o"
    phi_3_5 = "phi_3_5"
    mistral_large = "mistral_large"
    mistral_nemo = "mistral_nemo"
    gemma_2_2b = "gemma_2_2b"
    gemma_2_9b = "gemma_2_9b"
    gemma_2_27b = "gemma_2_27b"
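langchain_model_from (below) validates its name argument against this enum via __members__; the same check in isolation, as a sketch:

# Sketch: validating a user-supplied model id the way langchain_model_from does
def is_known_model(name: str) -> bool:
    return name in ModelName.__members__

assert is_known_model("gpt_4o")
assert not is_known_model("gpt-4o")  # member names use underscores, not hyphens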
class KilnModelProvider(BaseModel):
    name: ModelProviderName
    # Allow overriding the model level setting
    supports_structured_output: bool = True
    provider_options: Dict = {}
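Constructing an entry is plain Pydantic. Note that Pydantic deep-copies the mutable {} default for provider_options per instance, so entries don't share state the way a mutable dataclass default would. A sketch:

# Sketch: constructing a provider entry; Pydantic validates the enum field
provider = KilnModelProvider(
    name=ModelProviderName.ollama,
    provider_options={"model": "llama3.1"},
)
assert provider.supports_structured_output  # defaults to True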
class KilnModel(BaseModel):
    family: str
    name: str
    friendly_name: str
    providers: List[KilnModelProvider]
    supports_structured_output: bool = True
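The per-provider supports_structured_output flag exists to "allow overriding the model level setting". A reasonable reading, though an assumption not confirmed by this module, is that the effective capability of a (model, provider) pair is the conjunction of both flags:

# Assumption: structured output is supported only when neither the model-level
# nor the provider-level flag opts out (e.g. Llama 3.1 70B on Bedrock).
def effective_structured_output(model: KilnModel, provider: KilnModelProvider) -> bool:
    return model.supports_structured_output and provider.supports_structured_output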
def provider_name_from_id(id: str) -> str:
    if id in ModelProviderName.__members__:
        enum_id = ModelProviderName(id)
        match enum_id:
            case ModelProviderName.amazon_bedrock:
                return "Amazon Bedrock"
            case ModelProviderName.openrouter:
                return "OpenRouter"
            case ModelProviderName.groq:
                return "Groq"
            case ModelProviderName.ollama:
                return "Ollama"
            case ModelProviderName.openai:
                return "OpenAI"
            case _:
                # triggers pyright warning if I miss a case
                raise_exhaustive_error(enum_id)

    return "Unknown provider: " + id
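Illustrative calls, grounded in the branches above:

assert provider_name_from_id("amazon_bedrock") == "Amazon Bedrock"
assert provider_name_from_id("not_a_provider") == "Unknown provider: not_a_provider"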
def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
    if name not in ModelName.__members__:
        raise ValueError(f"Invalid name: {name}")

    # Select the model from built_in_models using the name
    # (the None default avoids StopIteration if the name is missing)
    model = next(filter(lambda m: m.name == name, built_in_models), None)
    if model is None:
        raise ValueError(f"Model {name} not found")

    # If a provider is specified, select it from the model's providers
    provider: KilnModelProvider | None = None
    if model.providers is None or len(model.providers) == 0:
        raise ValueError(f"Model {name} has no providers")
    elif provider_name is None:
        # TODO: priority order
        provider = model.providers[0]
    else:
        provider = next(
            filter(lambda p: p.name == provider_name, model.providers), None
        )
        if provider is None:
            raise ValueError(f"Provider {provider_name} not found for model {name}")

    check_provider_warnings(provider.name)

    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't allow passing these, so ugly hack to set env vars
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.ollama:
        return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://kiln-ai.com/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")
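A minimal end-to-end sketch. It assumes the relevant API key (Groq's, in this example) is already stored in Kiln's shared Config; otherwise check_provider_warnings raises a ValueError with setup instructions:

# Sketch: resolve a registered model to a LangChain chat model and invoke it
from kiln_ai.adapters.ml_model_list import langchain_model_from

chat = langchain_model_from("llama_3_1_8b", provider_name="groq")
response = chat.invoke("Say hello in one word.")
print(response.content)

Omitting provider_name selects the model's first listed provider; a priority order is still a TODO in the source above.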