kiln_ai.adapters.ml_model_list

"""
Provides model configuration and management for various LLM providers and models.
This module handles the integration with different AI model providers and their respective models,
including configuration, validation, and instantiation of language models.
"""

import os
from dataclasses import dataclass
from enum import Enum
from os import getenv
from typing import Dict, List, NoReturn

import httpx
from langchain_aws import ChatBedrockConverse
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

from ..utils.config import Config
class ModelProviderName(str, Enum):
    """
    Enumeration of supported AI model providers.
    """

    openai = "openai"
    groq = "groq"
    amazon_bedrock = "amazon_bedrock"
    ollama = "ollama"
    openrouter = "openrouter"


class ModelFamily(str, Enum):
    """
    Enumeration of supported model families/architectures.
    """

    gpt = "gpt"
    llama = "llama"
    phi = "phi"
    mistral = "mistral"
    gemma = "gemma"


# Where models have instruct and raw versions, instruct is default and raw is specified
class ModelName(str, Enum):
    """
    Enumeration of specific model versions supported by the system.
    Where models have instruct and raw versions, instruct is default and raw is specified.
    """

    llama_3_1_8b = "llama_3_1_8b"
    llama_3_1_70b = "llama_3_1_70b"
    llama_3_1_405b = "llama_3_1_405b"
    gpt_4o_mini = "gpt_4o_mini"
    gpt_4o = "gpt_4o"
    phi_3_5 = "phi_3_5"
    mistral_large = "mistral_large"
    mistral_nemo = "mistral_nemo"
    gemma_2_2b = "gemma_2_2b"
    gemma_2_9b = "gemma_2_9b"
    gemma_2_27b = "gemma_2_27b"


class KilnModelProvider(BaseModel):
    """
    Configuration for a specific model provider.

    Attributes:
        name: The provider's identifier
        supports_structured_output: Whether the provider supports structured output formats
        provider_options: Additional provider-specific configuration options
    """

    name: ModelProviderName
    supports_structured_output: bool = True
    provider_options: Dict = {}


class KilnModel(BaseModel):
    """
    Configuration for a specific AI model.

    Attributes:
        family: The model's architecture family
        name: The model's identifier
        friendly_name: Human-readable name for the model
        providers: List of providers that offer this model
        supports_structured_output: Whether the model supports structured output formats
    """

    family: str
    name: str
    friendly_name: str
    providers: List[KilnModelProvider]
    supports_structured_output: bool = True
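These two pydantic models drive the registry below. A sketch of a hypothetical custom entry (the ids here are illustrative, not part of the built-in list):

    custom_model = KilnModel(
        family=ModelFamily.llama,
        name="my_custom_llama",  # hypothetical id, not in the ModelName enum
        friendly_name="My Custom Llama",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "llama3.1:8b"},
            ),
        ],
    )

Note that langchain_model_from below only accepts names from the ModelName enum, so custom entries like this would need their own lookup path.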
built_in_models: List[KilnModel] = [
    # GPT 4o Mini
    KilnModel(
        family=ModelFamily.gpt,
        name=ModelName.gpt_4o_mini,
        friendly_name="GPT 4o Mini",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openai,
                provider_options={"model": "gpt-4o-mini"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "openai/gpt-4o-mini"},
            ),
        ],
    ),
    # GPT 4o
    KilnModel(
        family=ModelFamily.gpt,
        name=ModelName.gpt_4o,
        friendly_name="GPT 4o",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openai,
                provider_options={"model": "gpt-4o"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "openai/gpt-4o-2024-08-06"},
            ),
        ],
    ),
    # Llama 3.1 8b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_8b,
        friendly_name="Llama 3.1 8B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.groq,
                provider_options={"model": "llama-3.1-8b-instant"},
            ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                provider_options={
                    "model": "meta.llama3-1-8b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "llama3.1"},  # 8b is default
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
            ),
        ],
    ),
    # Llama 3.1 70b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_70b,
        friendly_name="Llama 3.1 70B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.groq,
                provider_options={"model": "llama-3.1-70b-versatile"},
            ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                # TODO: this should work, but a bug in the Bedrock response schema blocks it
                supports_structured_output=False,
                provider_options={
                    "model": "meta.llama3-1-70b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
            ),
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={"model": "llama3.1:70b"},
            # ),
        ],
    ),
    # Llama 3.1 405b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_405b,
        friendly_name="Llama 3.1 405B",
        providers=[
            # TODO: bring back when groq does: https://console.groq.com/docs/models
            # KilnModelProvider(
            #     name=ModelProviderName.groq,
            #     provider_options={"model": "llama-3.1-405b-instruct-v1:0"},
            # ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                provider_options={
                    "model": "meta.llama3-1-405b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={"model": "llama3.1:405b"},
            # ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
            ),
        ],
    ),
    # Mistral Nemo
    KilnModel(
        family=ModelFamily.mistral,
        name=ModelName.mistral_nemo,
        friendly_name="Mistral Nemo",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "mistralai/mistral-nemo"},
            ),
        ],
    ),
    # Mistral Large
    KilnModel(
        family=ModelFamily.mistral,
        name=ModelName.mistral_large,
        friendly_name="Mistral Large",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                provider_options={
                    "model": "mistral.mistral-large-2407-v1:0",
                    "region_name": "us-west-2",  # only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "mistralai/mistral-large"},
            ),
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={"model": "mistral-large"},
            # ),
        ],
    ),
    # Phi 3.5
    KilnModel(
        family=ModelFamily.phi,
        name=ModelName.phi_3_5,
        friendly_name="Phi 3.5",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "phi3.5"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
            ),
        ],
    ),
    # Gemma 2 2.6b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_2b,
        friendly_name="Gemma 2 2B",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={
                    "model": "gemma2:2b",
                },
            ),
        ],
    ),
    # Gemma 2 9b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_9b,
        friendly_name="Gemma 2 9B",
        supports_structured_output=False,
        providers=[
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={
            #         "model": "gemma2:9b",
            #     },
            # ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "google/gemma-2-9b-it"},
            ),
        ],
    ),
    # Gemma 2 27b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_27b,
        friendly_name="Gemma 2 27B",
        supports_structured_output=False,
        providers=[
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={
            #         "model": "gemma2:27b",
            #     },
            # ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "google/gemma-2-27b-it"},
            ),
        ],
    ),
]
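The registry is a plain list, so ad-hoc queries are simple comprehensions. For example, collecting the friendly names of models that support structured output:

    structured = [m.friendly_name for m in built_in_models if m.supports_structured_output]
    print(structured)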
def provider_name_from_id(id: str) -> str:
    """
    Converts a provider ID to its human-readable name.

    Args:
        id: The provider identifier string

    Returns:
        The human-readable name of the provider, or "Unknown provider: <id>"
        if the ID is not a known ModelProviderName
    """
    if id in ModelProviderName.__members__:
        enum_id = ModelProviderName(id)
        match enum_id:
            case ModelProviderName.amazon_bedrock:
                return "Amazon Bedrock"
            case ModelProviderName.openrouter:
                return "OpenRouter"
            case ModelProviderName.groq:
                return "Groq"
            case ModelProviderName.ollama:
                return "Ollama"
            case ModelProviderName.openai:
                return "OpenAI"
            case _:
                # triggers a pyright warning if a case is missed
                raise_exhaustive_error(enum_id)

    return "Unknown provider: " + id
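A quick example of the mapping; note the fallback path returns a string rather than raising:

    provider_name_from_id("groq")        # "Groq"
    provider_name_from_id("not_a_real")  # "Unknown provider: not_a_real"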
def raise_exhaustive_error(value: NoReturn) -> NoReturn:
    raise ValueError(f"Unhandled enum value: {value}")


@dataclass
class ModelProviderWarning:
    required_config_keys: List[str]
    message: str


provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
    ModelProviderName.amazon_bedrock: ModelProviderWarning(
        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
    ),
    ModelProviderName.openrouter: ModelProviderWarning(
        required_config_keys=["open_router_api_key"],
        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
    ),
    ModelProviderName.groq: ModelProviderWarning(
        required_config_keys=["groq_api_key"],
        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
    ),
    ModelProviderName.openai: ModelProviderWarning(
        required_config_keys=["open_ai_api_key"],
        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
    ),
}
def get_config_value(key: str):
    try:
        return getattr(Config.shared(), key)
    except AttributeError:
        return None
def check_provider_warnings(provider_name: ModelProviderName):
    """
    Validates that required configuration is present for a given provider.

    Args:
        provider_name: The provider to check

    Raises:
        ValueError: If required configuration keys are missing
    """
    warning_check = provider_warnings.get(provider_name)
    if warning_check is None:
        return
    for key in warning_check.required_config_keys:
        if get_config_value(key) is None:
            raise ValueError(warning_check.message)
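A sketch of guarding a call site with this check; the config key names come from the provider_warnings table above:

    try:
        check_provider_warnings(ModelProviderName.groq)
    except ValueError as e:
        print(e)  # setup hint, e.g. where to get an API key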
def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
    """
    Creates a LangChain chat model instance for the specified model and provider.

    Args:
        name: The name of the model to instantiate
        provider_name: Optional specific provider to use (defaults to the first available)

    Returns:
        A configured LangChain chat model instance

    Raises:
        ValueError: If the model/provider combination is invalid or misconfigured
    """
    if name not in ModelName.__members__:
        raise ValueError(f"Invalid name: {name}")

    # Select the model from built_in_models using the name
    model = next(filter(lambda m: m.name == name, built_in_models), None)
    if model is None:
        raise ValueError(f"Model {name} not found")

    # If a provider is requested, select it from the model's providers;
    # otherwise fall back to the first provider listed for the model
    provider: KilnModelProvider | None = None
    if model.providers is None or len(model.providers) == 0:
        raise ValueError(f"Model {name} has no providers")
    elif provider_name is None:
        # TODO: priority order
        provider = model.providers[0]
    else:
        provider = next(
            filter(lambda p: p.name == provider_name, model.providers), None
        )
    if provider is None:
        raise ValueError(f"Provider {provider_name} not found for model {name}")

    check_provider_warnings(provider.name)

    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't accept these as arguments, so set env vars instead
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.ollama:
        return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://kiln-ai.com/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")
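When a model is offered by several providers, pass the provider id explicitly; otherwise the first entry in the model's providers list is used. A sketch, assuming an OpenRouter key is configured:

    chat = langchain_model_from("llama_3_1_8b", provider_name="openrouter")
    print(chat.invoke("Reply with OK").content)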
def ollama_base_url() -> str:
    """
    Gets the base URL for Ollama API connections.

    Returns:
        The base URL to use for Ollama API calls: the OLLAMA_BASE_URL
        environment variable if set, otherwise the localhost default
    """
    env_base_url = os.getenv("OLLAMA_BASE_URL")
    if env_base_url is not None:
        return env_base_url
    return "http://localhost:11434"
async def ollama_online() -> bool:
    """
    Checks if the Ollama service is available and responding.

    Returns:
        True if Ollama is available and responding, False otherwise
    """
    try:
        # Use an async client so this coroutine doesn't block the event loop
        async with httpx.AsyncClient() as client:
            await client.get(ollama_base_url() + "/api/tags")
    except httpx.RequestError:
        return False
    return True
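A small async sketch tying the Ollama helpers together (assumes a local Ollama server; set OLLAMA_BASE_URL to point elsewhere):

    import asyncio

    async def main() -> None:
        if await ollama_online():
            chat = langchain_model_from("llama_3_1_8b", provider_name="ollama")
            print((await chat.ainvoke("Say hi")).content)
        else:
            print(f"No Ollama server at {ollama_base_url()}")

    asyncio.run(main())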