kiln_ai.adapters.ml_model_list

  1import os
  2from dataclasses import dataclass
  3from enum import Enum
  4from os import getenv
  5from typing import Any, Dict, List, NoReturn
  6
  7import httpx
  8import requests
  9from langchain_aws import ChatBedrockConverse
 10from langchain_core.language_models.chat_models import BaseChatModel
 11from langchain_groq import ChatGroq
 12from langchain_ollama import ChatOllama
 13from langchain_openai import ChatOpenAI
 14from pydantic import BaseModel
 15
 16from ..utils.config import Config
 17
# NOTE(review): this string literal comes after the imports, so Python does not
# treat it as the module docstring (the module's __doc__ stays None); it is an
# inert expression statement. Move it above the imports to make it the real
# module docstring.
"""
Provides model configuration and management for various LLM providers and models.
This module handles the integration with different AI model providers and their respective models,
including configuration, validation, and instantiation of language models.
"""
 23
 24
class ModelProviderName(str, Enum):
    """
    Enumeration of supported AI model providers.

    Members subclass ``str``, so they compare equal to their raw string values
    (e.g. ``ModelProviderName.openai == "openai"``). Lookups elsewhere in this
    module compare provider names given as plain strings against these members.
    """

    # Member names deliberately equal their values: provider_name_from_id checks
    # membership via ModelProviderName.__members__ (i.e. by member name).
    openai = "openai"
    groq = "groq"
    amazon_bedrock = "amazon_bedrock"
    ollama = "ollama"
    openrouter = "openrouter"
 35
 36
class ModelFamily(str, Enum):
    """
    Enumeration of supported model families/architectures.

    Used as ``KilnModel.family``. Derivative models are grouped under their base
    family (e.g. Nemotron 70B is registered with the ``llama`` family).
    """

    gpt = "gpt"
    llama = "llama"
    phi = "phi"
    mistral = "mistral"
    gemma = "gemma"
    gemini = "gemini"
    claude = "claude"
 50
# Where models have instruct and raw versions, instruct is default and raw is specified
class ModelName(str, Enum):
    """
    Enumeration of specific model versions supported by the system.
    Where models have instruct and raw versions, instruct is default and raw is specified.

    Member names equal their values. ``langchain_model_from`` validates a
    requested name against ``ModelName.__members__`` (by member name) and then
    expects a matching ``KilnModel`` in ``built_in_models``.
    """

    llama_3_1_8b = "llama_3_1_8b"
    llama_3_1_70b = "llama_3_1_70b"
    llama_3_1_405b = "llama_3_1_405b"
    llama_3_2_3b = "llama_3_2_3b"
    llama_3_2_11b = "llama_3_2_11b"
    llama_3_2_90b = "llama_3_2_90b"
    gpt_4o_mini = "gpt_4o_mini"
    gpt_4o = "gpt_4o"
    phi_3_5 = "phi_3_5"
    mistral_large = "mistral_large"
    mistral_nemo = "mistral_nemo"
    gemma_2_2b = "gemma_2_2b"
    gemma_2_9b = "gemma_2_9b"
    gemma_2_27b = "gemma_2_27b"
    claude_3_5_haiku = "claude_3_5_haiku"
    claude_3_5_sonnet = "claude_3_5_sonnet"
    gemini_1_5_flash = "gemini_1_5_flash"
    gemini_1_5_flash_8b = "gemini_1_5_flash_8b"
    gemini_1_5_pro = "gemini_1_5_pro"
    nemotron_70b = "nemotron_70b"
 79
class KilnModelProvider(BaseModel):
    """
    Configuration for a specific model provider.

    Attributes:
        name: The provider's identifier
        supports_structured_output: Whether the provider supports structured output formats
        supports_data_gen: Whether the provider supports data generation
        provider_finetune_id: Optional provider-side model ID. In the built-in
            list it is set only on OpenAI entries (presumably the base model
            used for fine-tuning -- TODO confirm).
        provider_options: Additional provider-specific configuration options.
            Keys used in this module: "model" (provider-side model ID),
            "region_name" (Bedrock), "model_aliases" (alternate Ollama names).
    """

    name: ModelProviderName
    supports_structured_output: bool = True
    supports_data_gen: bool = True
    provider_finetune_id: str | None = None
    provider_options: Dict = {}
 95
 96
class KilnModel(BaseModel):
    """
    Configuration for a specific AI model.

    Attributes:
        family: The model's architecture family (populated with ModelFamily members)
        name: The model's identifier (populated with ModelName members)
        friendly_name: Human-readable name for the model
        providers: List of providers that offer this model. Order matters:
            langchain_model_from uses the first entry when no provider is given.
        supports_structured_output: Whether the model supports structured output formats
    """

    family: str
    name: str
    friendly_name: str
    providers: List[KilnModelProvider]
    supports_structured_output: bool = True
114
115
# Registry of every model shipped with Kiln and the providers that serve it.
#
# Ordering matters: when langchain_model_from() is called without a provider
# name it uses model.providers[0], so the first provider listed for each model
# is that model's default.
#
# provider_options["model"] is the provider-side model ID. For OpenAI, Groq,
# Bedrock and OpenRouter the whole provider_options dict is splatted into the
# LangChain chat-model constructor (so Bedrock's "region_name" is passed too).
# Ollama entries may add "model_aliases": alternate names tried in order.
#
# supports_structured_output / supports_data_gen default to True on
# KilnModelProvider; False marks a capability that is off for that
# model/provider pair (inline notes explain some cases).
built_in_models: List[KilnModel] = [
    # GPT 4o Mini
    KilnModel(
        family=ModelFamily.gpt,
        name=ModelName.gpt_4o_mini,
        friendly_name="GPT 4o Mini",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openai,
                provider_options={"model": "gpt-4o-mini"},
                provider_finetune_id="gpt-4o-mini-2024-07-18",
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "openai/gpt-4o-mini"},
            ),
        ],
    ),
    # GPT 4o
    KilnModel(
        family=ModelFamily.gpt,
        name=ModelName.gpt_4o,
        friendly_name="GPT 4o",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openai,
                provider_options={"model": "gpt-4o"},
                provider_finetune_id="gpt-4o-2024-08-06",
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "openai/gpt-4o-2024-08-06"},
            ),
        ],
    ),
    # Claude 3.5 Haiku
    KilnModel(
        family=ModelFamily.claude,
        name=ModelName.claude_3_5_haiku,
        friendly_name="Claude 3.5 Haiku",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "anthropic/claude-3-5-haiku"},
            ),
        ],
    ),
    # Claude 3.5 Sonnet
    KilnModel(
        family=ModelFamily.claude,
        name=ModelName.claude_3_5_sonnet,
        friendly_name="Claude 3.5 Sonnet",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "anthropic/claude-3.5-sonnet"},
            ),
        ],
    ),
    # Gemini 1.5 Pro
    KilnModel(
        family=ModelFamily.gemini,
        name=ModelName.gemini_1_5_pro,
        friendly_name="Gemini 1.5 Pro",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,  # it should, but doesn't work on openrouter
                supports_data_gen=False,  # doesn't work on openrouter
                provider_options={"model": "google/gemini-pro-1.5"},
            ),
        ],
    ),
    # Gemini 1.5 Flash
    KilnModel(
        family=ModelFamily.gemini,
        name=ModelName.gemini_1_5_flash,
        friendly_name="Gemini 1.5 Flash",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_data_gen=False,
                provider_options={"model": "google/gemini-flash-1.5"},
            ),
        ],
    ),
    # Gemini 1.5 Flash 8B
    KilnModel(
        family=ModelFamily.gemini,
        name=ModelName.gemini_1_5_flash_8b,
        friendly_name="Gemini 1.5 Flash 8B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "google/gemini-flash-1.5-8b"},
            ),
        ],
    ),
    # Nemotron 70B (a Llama 3.1 derivative, hence the llama family)
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.nemotron_70b,
        friendly_name="Nemotron 70B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "nvidia/llama-3.1-nemotron-70b-instruct"},
            ),
        ],
    ),
    # Llama 3.1-8b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_8b,
        friendly_name="Llama 3.1 8B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.groq,
                provider_options={"model": "llama-3.1-8b-instant"},
            ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={
                    "model": "meta.llama3-1-8b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_data_gen=False,
                provider_options={
                    "model": "llama3.1:8b",
                    "model_aliases": ["llama3.1"],  # 8b is default
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
            ),
        ],
    ),
    # Llama 3.1 70b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_70b,
        friendly_name="Llama 3.1 70B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.groq,
                provider_options={"model": "llama-3.1-70b-versatile"},
            ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                # not sure how AWS manages to break this, but it's not working
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={
                    "model": "meta.llama3-1-70b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "llama3.1:70b"},
            ),
        ],
    ),
    # Llama 3.1 405b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_405b,
        friendly_name="Llama 3.1 405B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                supports_data_gen=False,
                provider_options={
                    "model": "meta.llama3-1-405b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "llama3.1:405b"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
            ),
        ],
    ),
    # Mistral Nemo
    KilnModel(
        family=ModelFamily.mistral,
        name=ModelName.mistral_nemo,
        friendly_name="Mistral Nemo",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "mistralai/mistral-nemo"},
            ),
        ],
    ),
    # Mistral Large
    KilnModel(
        family=ModelFamily.mistral,
        name=ModelName.mistral_large,
        friendly_name="Mistral Large",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                provider_options={
                    "model": "mistral.mistral-large-2407-v1:0",
                    "region_name": "us-west-2",  # only in west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "mistralai/mistral-large"},
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "mistral-large"},
            ),
        ],
    ),
    # Llama 3.2 3B
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_2_3b,
        friendly_name="Llama 3.2 3B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "meta-llama/llama-3.2-3b-instruct"},
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "llama3.2"},
            ),
        ],
    ),
    # Llama 3.2 11B
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_2_11b,
        friendly_name="Llama 3.2 11B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "llama3.2-vision"},
            ),
        ],
    ),
    # Llama 3.2 90B
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_2_90b,
        friendly_name="Llama 3.2 90B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "llama3.2-vision:90b"},
            ),
        ],
    ),
    # Phi 3.5
    KilnModel(
        family=ModelFamily.phi,
        name=ModelName.phi_3_5,
        friendly_name="Phi 3.5",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "phi3.5"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
            ),
        ],
    ),
    # Gemma 2 2.6b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_2b,
        friendly_name="Gemma 2 2B",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_structured_output=False,
                supports_data_gen=False,
                provider_options={
                    "model": "gemma2:2b",
                },
            ),
        ],
    ),
    # Gemma 2 9b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_9b,
        friendly_name="Gemma 2 9B",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_data_gen=False,
                provider_options={
                    "model": "gemma2:9b",
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_data_gen=False,
                provider_options={"model": "google/gemma-2-9b-it"},
            ),
        ],
    ),
    # Gemma 2 27b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_27b,
        friendly_name="Gemma 2 27B",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                supports_data_gen=False,
                provider_options={
                    "model": "gemma2:27b",
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                supports_data_gen=False,
                provider_options={"model": "google/gemma-2-27b-it"},
            ),
        ],
    ),
]
495
496
def get_model_and_provider(
    model_name: str, provider_name: str
) -> tuple[KilnModel | None, KilnModelProvider | None]:
    """
    Look up a built-in model and one of its providers by name.

    Args:
        model_name: Model identifier to match against ``KilnModel.name``.
        provider_name: Provider identifier to match against ``KilnModelProvider.name``.

    Returns:
        ``(model, provider)`` when both are found, otherwise ``(None, None)``.
        The result is all-or-nothing: a known model paired with a provider it
        does not offer still yields ``(None, None)``.
    """
    model = next((m for m in built_in_models if m.name == model_name), None)
    if model is None:
        return None, None

    provider = next((p for p in model.providers if p.name == provider_name), None)
    if provider is None:
        # All or nothing: never return a model without its provider.
        return None, None

    return model, provider
508
509
510def provider_name_from_id(id: str) -> str:
511    """
512    Converts a provider ID to its human-readable name.
513
514    Args:
515        id: The provider identifier string
516
517    Returns:
518        The human-readable name of the provider
519
520    Raises:
521        ValueError: If the provider ID is invalid or unhandled
522    """
523    if id in ModelProviderName.__members__:
524        enum_id = ModelProviderName(id)
525        match enum_id:
526            case ModelProviderName.amazon_bedrock:
527                return "Amazon Bedrock"
528            case ModelProviderName.openrouter:
529                return "OpenRouter"
530            case ModelProviderName.groq:
531                return "Groq"
532            case ModelProviderName.ollama:
533                return "Ollama"
534            case ModelProviderName.openai:
535                return "OpenAI"
536            case _:
537                # triggers pyright warning if I miss a case
538                raise_exhaustive_error(enum_id)
539
540    return "Unknown provider: " + id
541
542
def raise_exhaustive_error(value: NoReturn) -> NoReturn:
    """
    Fail loudly when a value reaches a branch that should be unreachable.

    Typing the parameter as ``NoReturn`` makes pyright report any call site
    where ``value`` has not been narrowed away, e.g. a ``match`` over an enum
    that misses a member. At runtime it always raises.

    Raises:
        ValueError: Always, naming the unhandled value.
    """
    message = f"Unhandled enum value: {value}"
    raise ValueError(message)
545
546
@dataclass
class ModelProviderWarning:
    """
    Configuration a provider requires, plus the error text shown when it is missing.

    Attributes:
        required_config_keys: Config attribute names that must all be set
            (non-None) for the provider to be considered usable.
        message: Error text raised as a ValueError by check_provider_warnings
            when any required key is unset.
    """

    required_config_keys: List[str]
    message: str
551
552
# Config keys each hosted provider needs. provider_enabled() treats a provider as
# enabled only if all of its keys are set; check_provider_warnings() raises a
# ValueError carrying `message` if any key is missing. Ollama has no entry: it
# needs no API key and is probed over HTTP instead.
provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
    ModelProviderName.amazon_bedrock: ModelProviderWarning(
        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
    ),
    ModelProviderName.openrouter: ModelProviderWarning(
        required_config_keys=["open_router_api_key"],
        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
    ),
    ModelProviderName.groq: ModelProviderWarning(
        required_config_keys=["groq_api_key"],
        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
    ),
    ModelProviderName.openai: ModelProviderWarning(
        required_config_keys=["open_ai_api_key"],
        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
    ),
}
571
572
async def provider_enabled(provider_name: ModelProviderName) -> bool:
    """
    Report whether a provider is usable right now.

    Ollama counts as enabled when a connection succeeds and reports at least one
    supported installed model; any exception while probing yields False. Every
    other provider is enabled only if it has an entry in ``provider_warnings``
    and all of that entry's required config keys are set.

    Args:
        provider_name: The provider to check.

    Returns:
        True if the provider can be used, False otherwise.
    """
    if provider_name == ModelProviderName.ollama:
        try:
            connection = await get_ollama_connection()
            return connection is not None and len(connection.models) > 0
        except Exception:
            return False

    requirement = provider_warnings.get(provider_name)
    if requirement is None:
        return False
    return all(
        get_config_value(config_key) is not None
        for config_key in requirement.required_config_keys
    )
588
589
def get_config_value(key: str):
    """
    Read a setting from the shared Config, treating unknown keys as unset.

    Args:
        key: Name of the config attribute to read.

    Returns:
        The configured value, or None if the lookup raises AttributeError.
    """
    try:
        value = Config.shared().__getattr__(key)
    except AttributeError:
        value = None
    return value
595
596
def check_provider_warnings(provider_name: ModelProviderName):
    """
    Ensure every config key a provider depends on is set.

    Providers with no entry in ``provider_warnings`` (e.g. Ollama) pass
    without any check.

    Args:
        provider_name: The provider whose configuration should be verified.

    Raises:
        ValueError: Carrying the provider's help message, if any required
            configuration key is unset.
    """
    requirement = provider_warnings.get(provider_name)
    if requirement is None:
        return

    missing_any = any(
        get_config_value(config_key) is None
        for config_key in requirement.required_config_keys
    )
    if missing_any:
        raise ValueError(requirement.message)
613
614
async def langchain_model_from(
    name: str, provider_name: str | None = None
) -> BaseChatModel:
    """
    Creates a LangChain chat model instance for the specified model and provider.

    Args:
        name: The name of the model to instantiate (a ModelName member name)
        provider_name: Optional specific provider to use. When omitted, the
            model's first listed provider is used.

    Returns:
        A configured LangChain chat model instance

    Raises:
        ValueError: If the model/provider combination is invalid, the provider is
            missing required configuration, or (for Ollama) the server cannot be
            reached or the model is not installed.
    """
    if name not in ModelName.__members__:
        raise ValueError(f"Invalid name: {name}")

    # Select the model from built_in_models using the name. The explicit None
    # default matters: a bare next() raises StopIteration on a miss, which a
    # coroutine re-raises as an opaque RuntimeError instead of this ValueError.
    model = next((m for m in built_in_models if m.name == name), None)
    if model is None:
        raise ValueError(f"Model {name} not found")

    if not model.providers:
        raise ValueError(f"Model {name} has no providers")

    # Use the requested provider if given, otherwise the model's first provider.
    provider: KilnModelProvider | None
    if provider_name is None:
        # TODO: priority order
        provider = model.providers[0]
    else:
        provider = next(
            (p for p in model.providers if p.name == provider_name), None
        )
    if provider is None:
        raise ValueError(f"Provider {provider_name} not found for model {name}")

    # Raises ValueError with setup instructions if required config keys are unset.
    check_provider_warnings(provider.name)

    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't allow passing these, so ugly hack to set env vars.
        # NOTE: this mutates process-wide environment variables.
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.ollama:
        # Ollama model naming is pretty flexible. We try a few versions of the model name
        potential_model_names = []
        if "model" in provider.provider_options:
            potential_model_names.append(provider.provider_options["model"])
        if "model_aliases" in provider.provider_options:
            potential_model_names.extend(provider.provider_options["model_aliases"])

        # Get the list of models Ollama supports
        ollama_connection = await get_ollama_connection()
        if ollama_connection is None:
            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")

        for model_name in potential_model_names:
            if ollama_model_supported(ollama_connection, model_name):
                return ChatOllama(model=model_name, base_url=ollama_base_url())

        raise ValueError(f"Model {name} not installed on Ollama")
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://getkiln.ai/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")
707
708
def ollama_base_url() -> str:
    """
    Resolve the base URL used for Ollama API requests.

    Returns:
        The value of the ``OLLAMA_BASE_URL`` environment variable whenever it is
        set (even to an empty string), otherwise ``"http://localhost:11434"``.
    """
    return os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
721
722
async def ollama_online() -> bool:
    """
    Checks if the Ollama service is available and responding.

    Any HTTP response (whatever its status code) counts as online; only an
    ``httpx.RequestError`` (e.g. connection refused, timeout) counts as offline.

    Returns:
        True if Ollama is available and responding, False otherwise
    """
    try:
        # Use the async client: a synchronous httpx.get inside this coroutine
        # would block the event loop for the whole request.
        async with httpx.AsyncClient() as client:
            await client.get(ollama_base_url() + "/api/tags")
    except httpx.RequestError:
        return False
    return True
735
736
class OllamaConnection(BaseModel):
    """
    Result of probing an Ollama server for installed models.

    Attributes:
        message: Human-readable status text (e.g. "Ollama connected").
        models: Installed model names that match a supported built-in Ollama
            model; empty when none are installed.
    """

    message: str
    models: List[str]
740
741
# Parse the Ollama /api/tags response
def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
    """
    Build an OllamaConnection from a decoded Ollama ``/api/tags`` response.

    An installed model is reported when its name matches a built-in Ollama
    provider's ``model`` option or one of its ``model_aliases``, either exactly
    or with a ``:latest`` suffix.

    Args:
        tags: The decoded JSON body. Expected to be a dict whose ``"models"``
            entry is a list of dicts carrying a ``"model"`` name; any other
            shape (or malformed entries) is treated as "nothing installed".

    Returns:
        An OllamaConnection. Its ``models`` list is empty (with an explanatory
        message) when no supported model is installed.
    """
    # Names runnable on Ollama: each Ollama provider's "model" option and its
    # aliases, each also accepted with the ":latest" tag. A set keeps the
    # per-installed-model check O(1) instead of rescanning a list every time.
    supported_names: set[str] = set()
    for model in built_in_models:
        for provider in model.providers:
            if provider.name != ModelProviderName.ollama:
                continue
            candidates = [provider.provider_options["model"]]
            candidates.extend(provider.provider_options.get("model_aliases", []))
            for candidate in candidates:
                supported_names.add(candidate)
                supported_names.add(f"{candidate}:latest")

    models = tags.get("models") if isinstance(tags, dict) else None
    if isinstance(models, list):
        installed_names = [
            entry.get("model") for entry in models if isinstance(entry, dict)
        ]
        # Preserve Ollama's listing order; skip entries without a string name.
        available_supported_models = [
            installed
            for installed in installed_names
            if isinstance(installed, str) and installed in supported_names
        ]
        if available_supported_models:
            return OllamaConnection(
                message="Ollama connected",
                models=available_supported_models,
            )

    return OllamaConnection(
        message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
        models=[],
    )
781
782
async def get_ollama_connection() -> OllamaConnection | None:
    """
    Gets the connection status for Ollama.

    Queries ``ollama_base_url() + "/api/tags"`` with a 5 second timeout and
    parses the response with parse_ollama_tags.

    Returns:
        An OllamaConnection describing the installed supported models, or None
        if the request fails or the body cannot be decoded as JSON.
    """
    try:
        # Async client so the request does not block the event loop, which a
        # synchronous HTTP call inside this coroutine would.
        async with httpx.AsyncClient(timeout=5) as client:
            response = await client.get(ollama_base_url() + "/api/tags")
        tags = response.json()
    except Exception:
        return None

    return parse_ollama_tags(tags)
794
795
def ollama_model_supported(conn: OllamaConnection, model_name: str) -> bool:
    """
    Check whether ``model_name`` is among the models reported by ``conn``.

    Both the bare name and its ``<name>:latest`` form are accepted.

    Args:
        conn: Connection result whose ``models`` list is searched.
        model_name: Ollama model name to look for.

    Returns:
        True if either form of the name appears in ``conn.models``.
    """
    candidates = (model_name, f"{model_name}:latest")
    return any(candidate in conn.models for candidate in candidates)
class ModelProviderName(builtins.str, enum.Enum):
class ModelProviderName(str, Enum):
    """
    Enumeration of supported AI model providers.

    Members subclass ``str``, so they compare equal to their raw string values
    (e.g. ``ModelProviderName.openai == "openai"``).
    """

    # Member names deliberately equal their values; membership checks elsewhere
    # use ModelProviderName.__members__ (i.e. the member names).
    openai = "openai"
    groq = "groq"
    amazon_bedrock = "amazon_bedrock"
    ollama = "ollama"
    openrouter = "openrouter"

Enumeration of supported AI model providers.

openai = <ModelProviderName.openai: 'openai'>
groq = <ModelProviderName.groq: 'groq'>
amazon_bedrock = <ModelProviderName.amazon_bedrock: 'amazon_bedrock'>
ollama = <ModelProviderName.ollama: 'ollama'>
openrouter = <ModelProviderName.openrouter: 'openrouter'>
class ModelFamily(builtins.str, enum.Enum):
class ModelFamily(str, Enum):
    """
    Enumeration of supported model families/architectures.

    Used as ``KilnModel.family``; derivative models are grouped under their
    base family (e.g. Nemotron 70B is registered with the ``llama`` family).
    """

    gpt = "gpt"
    llama = "llama"
    phi = "phi"
    mistral = "mistral"
    gemma = "gemma"
    gemini = "gemini"
    claude = "claude"

Enumeration of supported model families/architectures.

gpt = <ModelFamily.gpt: 'gpt'>
llama = <ModelFamily.llama: 'llama'>
phi = <ModelFamily.phi: 'phi'>
mistral = <ModelFamily.mistral: 'mistral'>
gemma = <ModelFamily.gemma: 'gemma'>
gemini = <ModelFamily.gemini: 'gemini'>
claude = <ModelFamily.claude: 'claude'>
class ModelName(builtins.str, enum.Enum):
class ModelName(str, Enum):
    """
    Enumeration of specific model versions supported by the system.
    Where models have instruct and raw versions, instruct is default and raw is specified.

    Member names equal their values. ``langchain_model_from`` validates a
    requested name against ``ModelName.__members__`` (by member name) and then
    expects a matching ``KilnModel`` in ``built_in_models``.
    """

    llama_3_1_8b = "llama_3_1_8b"
    llama_3_1_70b = "llama_3_1_70b"
    llama_3_1_405b = "llama_3_1_405b"
    llama_3_2_3b = "llama_3_2_3b"
    llama_3_2_11b = "llama_3_2_11b"
    llama_3_2_90b = "llama_3_2_90b"
    gpt_4o_mini = "gpt_4o_mini"
    gpt_4o = "gpt_4o"
    phi_3_5 = "phi_3_5"
    mistral_large = "mistral_large"
    mistral_nemo = "mistral_nemo"
    gemma_2_2b = "gemma_2_2b"
    gemma_2_9b = "gemma_2_9b"
    gemma_2_27b = "gemma_2_27b"
    claude_3_5_haiku = "claude_3_5_haiku"
    claude_3_5_sonnet = "claude_3_5_sonnet"
    gemini_1_5_flash = "gemini_1_5_flash"
    gemini_1_5_flash_8b = "gemini_1_5_flash_8b"
    gemini_1_5_pro = "gemini_1_5_pro"
    nemotron_70b = "nemotron_70b"

Enumeration of specific model versions supported by the system. Where models have instruct and raw versions, instruct is default and raw is specified.

llama_3_1_8b = <ModelName.llama_3_1_8b: 'llama_3_1_8b'>
llama_3_1_70b = <ModelName.llama_3_1_70b: 'llama_3_1_70b'>
llama_3_1_405b = <ModelName.llama_3_1_405b: 'llama_3_1_405b'>
llama_3_2_3b = <ModelName.llama_3_2_3b: 'llama_3_2_3b'>
llama_3_2_11b = <ModelName.llama_3_2_11b: 'llama_3_2_11b'>
llama_3_2_90b = <ModelName.llama_3_2_90b: 'llama_3_2_90b'>
gpt_4o_mini = <ModelName.gpt_4o_mini: 'gpt_4o_mini'>
gpt_4o = <ModelName.gpt_4o: 'gpt_4o'>
phi_3_5 = <ModelName.phi_3_5: 'phi_3_5'>
mistral_large = <ModelName.mistral_large: 'mistral_large'>
mistral_nemo = <ModelName.mistral_nemo: 'mistral_nemo'>
gemma_2_2b = <ModelName.gemma_2_2b: 'gemma_2_2b'>
gemma_2_9b = <ModelName.gemma_2_9b: 'gemma_2_9b'>
gemma_2_27b = <ModelName.gemma_2_27b: 'gemma_2_27b'>
claude_3_5_haiku = <ModelName.claude_3_5_haiku: 'claude_3_5_haiku'>
claude_3_5_sonnet = <ModelName.claude_3_5_sonnet: 'claude_3_5_sonnet'>
gemini_1_5_flash = <ModelName.gemini_1_5_flash: 'gemini_1_5_flash'>
gemini_1_5_flash_8b = <ModelName.gemini_1_5_flash_8b: 'gemini_1_5_flash_8b'>
gemini_1_5_pro = <ModelName.gemini_1_5_pro: 'gemini_1_5_pro'>
nemotron_70b = <ModelName.nemotron_70b: 'nemotron_70b'>
class KilnModelProvider(BaseModel):
    """
    Describes how one provider serves a particular model.

    Attributes:
        name: Which provider this entry refers to.
        supports_structured_output: Whether structured output is marked as supported
            through this provider. Defaults to True.
        supports_data_gen: Whether this provider is marked as usable for data
            generation. Defaults to True.
        provider_finetune_id: The provider's model identifier used for fine-tuning,
            when one exists (for example "gpt-4o-mini-2024-07-18" for OpenAI).
            Defaults to None.
        provider_options: Extra provider-specific settings. Keys such as "model",
            "model_aliases" and "region_name" are read here or forwarded to the
            LangChain client constructor. Defaults to an empty dict.
    """

    name: ModelProviderName

    # Capability flags; both default to "supported".
    supports_structured_output: bool = True
    supports_data_gen: bool = True

    provider_finetune_id: str | None = None

    # Pydantic copies mutable defaults per instance, so sharing {} here is safe.
    provider_options: Dict = {}

Configuration for a specific model provider.

Attributes: name: The provider's identifier supports_structured_output: Whether the provider supports structured output formats provider_options: Additional provider-specific configuration options

supports_structured_output: bool
supports_data_gen: bool
provider_finetune_id: str | None
provider_options: Dict
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

model_fields: ClassVar[Dict[str, pydantic.fields.FieldInfo]] = {'name': FieldInfo(annotation=ModelProviderName, required=True), 'supports_structured_output': FieldInfo(annotation=bool, required=False, default=True), 'supports_data_gen': FieldInfo(annotation=bool, required=False, default=True), 'provider_finetune_id': FieldInfo(annotation=Union[str, NoneType], required=False, default=None), 'provider_options': FieldInfo(annotation=Dict, required=False, default={})}

Metadata about the fields defined on the model, mapping of field names to [FieldInfo][pydantic.fields.FieldInfo] objects.

This replaces Model.__fields__ from Pydantic V1.

model_computed_fields: ClassVar[Dict[str, pydantic.fields.ComputedFieldInfo]] = {}

A dictionary of computed field names and their corresponding ComputedFieldInfo objects.

class KilnModel(BaseModel):
    """
    Configuration for a specific AI model.

    Attributes:
        family: The model's architecture family, e.g. "llama" or "gpt".
        name: The model's identifier, e.g. "llama_3_1_8b".
        friendly_name: Human-readable name for the model, e.g. "Llama 3.1 8B".
        providers: Providers that offer this model, in order. When no provider is
            requested, `langchain_model_from` picks the first entry.
        supports_structured_output: Whether the model supports structured output
            formats. Defaults to True.
    """

    family: str
    name: str
    friendly_name: str
    providers: List[KilnModelProvider]
    supports_structured_output: bool = True

Configuration for a specific AI model.

Attributes: family: The model's architecture family name: The model's identifier friendly_name: Human-readable name for the model providers: List of providers that offer this model supports_structured_output: Whether the model supports structured output formats

family: str
name: str
friendly_name: str
providers: List[KilnModelProvider]
supports_structured_output: bool
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

model_fields: ClassVar[Dict[str, pydantic.fields.FieldInfo]] = {'family': FieldInfo(annotation=str, required=True), 'name': FieldInfo(annotation=str, required=True), 'friendly_name': FieldInfo(annotation=str, required=True), 'providers': FieldInfo(annotation=List[KilnModelProvider], required=True), 'supports_structured_output': FieldInfo(annotation=bool, required=False, default=True)}

Metadata about the fields defined on the model, mapping of field names to [FieldInfo][pydantic.fields.FieldInfo] objects.

This replaces Model.__fields__ from Pydantic V1.

model_computed_fields: ClassVar[Dict[str, pydantic.fields.ComputedFieldInfo]] = {}

A dictionary of computed field names and their corresponding ComputedFieldInfo objects.

built_in_models: List[KilnModel] = [KilnModel(family='gpt', name='gpt_4o_mini', friendly_name='GPT 4o Mini', providers=[KilnModelProvider(name=<ModelProviderName.openai: 'openai'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id='gpt-4o-mini-2024-07-18', provider_options={'model': 'gpt-4o-mini'}), KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'openai/gpt-4o-mini'})], supports_structured_output=True), KilnModel(family='gpt', name='gpt_4o', friendly_name='GPT 4o', providers=[KilnModelProvider(name=<ModelProviderName.openai: 'openai'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id='gpt-4o-2024-08-06', provider_options={'model': 'gpt-4o'}), KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'openai/gpt-4o-2024-08-06'})], supports_structured_output=True), KilnModel(family='claude', name='claude_3_5_haiku', friendly_name='Claude 3.5 Haiku', providers=[KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'anthropic/claude-3-5-haiku'})], supports_structured_output=True), KilnModel(family='claude', name='claude_3_5_sonnet', friendly_name='Claude 3.5 Sonnet', providers=[KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'anthropic/claude-3.5-sonnet'})], supports_structured_output=True), KilnModel(family='gemini', name='gemini_1_5_pro', friendly_name='Gemini 1.5 Pro', providers=[KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=False, supports_data_gen=False, 
provider_finetune_id=None, provider_options={'model': 'google/gemini-pro-1.5'})], supports_structured_output=True), KilnModel(family='gemini', name='gemini_1_5_flash', friendly_name='Gemini 1.5 Flash', providers=[KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=True, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'google/gemini-flash-1.5'})], supports_structured_output=True), KilnModel(family='gemini', name='gemini_1_5_flash_8b', friendly_name='Gemini 1.5 Flash 8B', providers=[KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'google/gemini-flash-1.5-8b'})], supports_structured_output=True), KilnModel(family='llama', name='nemotron_70b', friendly_name='Nemotron 70B', providers=[KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'nvidia/llama-3.1-nemotron-70b-instruct'})], supports_structured_output=True), KilnModel(family='llama', name='llama_3_1_8b', friendly_name='Llama 3.1 8B', providers=[KilnModelProvider(name=<ModelProviderName.groq: 'groq'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'llama-3.1-8b-instant'}), KilnModelProvider(name=<ModelProviderName.amazon_bedrock: 'amazon_bedrock'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'meta.llama3-1-8b-instruct-v1:0', 'region_name': 'us-west-2'}), KilnModelProvider(name=<ModelProviderName.ollama: 'ollama'>, supports_structured_output=True, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'llama3.1:8b', 'model_aliases': ['llama3.1']}), KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, 
supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'meta-llama/llama-3.1-8b-instruct'})], supports_structured_output=True), KilnModel(family='llama', name='llama_3_1_70b', friendly_name='Llama 3.1 70B', providers=[KilnModelProvider(name=<ModelProviderName.groq: 'groq'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'llama-3.1-70b-versatile'}), KilnModelProvider(name=<ModelProviderName.amazon_bedrock: 'amazon_bedrock'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'meta.llama3-1-70b-instruct-v1:0', 'region_name': 'us-west-2'}), KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'meta-llama/llama-3.1-70b-instruct'}), KilnModelProvider(name=<ModelProviderName.ollama: 'ollama'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'llama3.1:70b'})], supports_structured_output=True), KilnModel(family='llama', name='llama_3_1_405b', friendly_name='Llama 3.1 405B', providers=[KilnModelProvider(name=<ModelProviderName.amazon_bedrock: 'amazon_bedrock'>, supports_structured_output=True, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'meta.llama3-1-405b-instruct-v1:0', 'region_name': 'us-west-2'}), KilnModelProvider(name=<ModelProviderName.ollama: 'ollama'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'llama3.1:405b'}), KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'meta-llama/llama-3.1-405b-instruct'})], supports_structured_output=True), KilnModel(family='mistral', 
name='mistral_nemo', friendly_name='Mistral Nemo', providers=[KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'mistralai/mistral-nemo'})], supports_structured_output=True), KilnModel(family='mistral', name='mistral_large', friendly_name='Mistral Large', providers=[KilnModelProvider(name=<ModelProviderName.amazon_bedrock: 'amazon_bedrock'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'mistral.mistral-large-2407-v1:0', 'region_name': 'us-west-2'}), KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'mistralai/mistral-large'}), KilnModelProvider(name=<ModelProviderName.ollama: 'ollama'>, supports_structured_output=True, supports_data_gen=True, provider_finetune_id=None, provider_options={'model': 'mistral-large'})], supports_structured_output=True), KilnModel(family='llama', name='llama_3_2_3b', friendly_name='Llama 3.2 3B', providers=[KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'meta-llama/llama-3.2-3b-instruct'}), KilnModelProvider(name=<ModelProviderName.ollama: 'ollama'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'llama3.2'})], supports_structured_output=True), KilnModel(family='llama', name='llama_3_2_11b', friendly_name='Llama 3.2 11B', providers=[KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'meta-llama/llama-3.2-11b-vision-instruct'}), KilnModelProvider(name=<ModelProviderName.ollama: 'ollama'>, 
supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'llama3.2-vision'})], supports_structured_output=True), KilnModel(family='llama', name='llama_3_2_90b', friendly_name='Llama 3.2 90B', providers=[KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'meta-llama/llama-3.2-90b-vision-instruct'}), KilnModelProvider(name=<ModelProviderName.ollama: 'ollama'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'llama3.2-vision:90b'})], supports_structured_output=True), KilnModel(family='phi', name='phi_3_5', friendly_name='Phi 3.5', providers=[KilnModelProvider(name=<ModelProviderName.ollama: 'ollama'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'phi3.5'}), KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'microsoft/phi-3.5-mini-128k-instruct'})], supports_structured_output=False), KilnModel(family='gemma', name='gemma_2_2b', friendly_name='Gemma 2 2B', providers=[KilnModelProvider(name=<ModelProviderName.ollama: 'ollama'>, supports_structured_output=False, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'gemma2:2b'})], supports_structured_output=False), KilnModel(family='gemma', name='gemma_2_9b', friendly_name='Gemma 2 9B', providers=[KilnModelProvider(name=<ModelProviderName.ollama: 'ollama'>, supports_structured_output=True, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'gemma2:9b'}), KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=True, supports_data_gen=False, provider_finetune_id=None, 
provider_options={'model': 'google/gemma-2-9b-it'})], supports_structured_output=False), KilnModel(family='gemma', name='gemma_2_27b', friendly_name='Gemma 2 27B', providers=[KilnModelProvider(name=<ModelProviderName.ollama: 'ollama'>, supports_structured_output=True, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'gemma2:27b'}), KilnModelProvider(name=<ModelProviderName.openrouter: 'openrouter'>, supports_structured_output=True, supports_data_gen=False, provider_finetune_id=None, provider_options={'model': 'google/gemma-2-27b-it'})], supports_structured_output=False)]
def get_model_and_provider(
    model_name: str, provider_name: str
) -> tuple[KilnModel | None, KilnModelProvider | None]:
    """
    Looks up a built-in model and one of its providers by name.

    Args:
        model_name: The model's `name` as listed in `built_in_models`.
        provider_name: The provider identifier to pick from that model's providers.

    Returns:
        `(model, provider)` when both are found, otherwise `(None, None)`.
        The result is all-or-nothing: a model is never returned without a provider.
    """
    model = next((m for m in built_in_models if m.name == model_name), None)
    if model is None:
        return None, None

    provider = next((p for p in model.providers if p.name == provider_name), None)
    if provider is None:
        # All or nothing: don't return a model the requested provider can't serve.
        return None, None
    return model, provider
def provider_name_from_id(id: str) -> str:
511def provider_name_from_id(id: str) -> str:
512    """
513    Converts a provider ID to its human-readable name.
514
515    Args:
516        id: The provider identifier string
517
518    Returns:
519        The human-readable name of the provider
520
521    Raises:
522        ValueError: If the provider ID is invalid or unhandled
523    """
524    if id in ModelProviderName.__members__:
525        enum_id = ModelProviderName(id)
526        match enum_id:
527            case ModelProviderName.amazon_bedrock:
528                return "Amazon Bedrock"
529            case ModelProviderName.openrouter:
530                return "OpenRouter"
531            case ModelProviderName.groq:
532                return "Groq"
533            case ModelProviderName.ollama:
534                return "Ollama"
535            case ModelProviderName.openai:
536                return "OpenAI"
537            case _:
538                # triggers pyright warning if I miss a case
539                raise_exhaustive_error(enum_id)
540
541    return "Unknown provider: " + id

Converts a provider ID to its human-readable name.

Args: id: The provider identifier string

Returns: The human-readable name of the provider

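Raises: ValueError: Only if the ID matches a provider enum member that has no display name mapped (a programming error). An unknown ID does not raise; the string "Unknown provider: <id>" is returned instead.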

def raise_exhaustive_error(value: NoReturn) -> NoReturn:
    """
    Fails loudly on an enum value that a `match` statement did not handle.

    Because the parameter is typed `NoReturn`, a type checker reports an error at
    the call site unless every other case was already handled.

    Raises:
        ValueError: Always, naming the unhandled value.
    """
    message = f"Unhandled enum value: {value}"
    raise ValueError(message)
@dataclass
class ModelProviderWarning:
    """
    Describes the configuration a provider needs before it can be used.

    Attributes:
        required_config_keys: Config keys that must all resolve to a non-None value.
        message: Text of the ValueError raised when any of those keys is missing.
    """

    required_config_keys: List[str]
    message: str
ModelProviderWarning(required_config_keys: List[str], message: str)
required_config_keys: List[str]
message: str
provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {<ModelProviderName.amazon_bedrock: 'amazon_bedrock'>: ModelProviderWarning(required_config_keys=['bedrock_access_key', 'bedrock_secret_key'], message='Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview'), <ModelProviderName.openrouter: 'openrouter'>: ModelProviderWarning(required_config_keys=['open_router_api_key'], message='Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys'), <ModelProviderName.groq: 'groq'>: ModelProviderWarning(required_config_keys=['groq_api_key'], message='Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys'), <ModelProviderName.openai: 'openai'>: ModelProviderWarning(required_config_keys=['open_ai_api_key'], message='Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys')}
async def provider_enabled(provider_name: ModelProviderName) -> bool:
    """
    Reports whether a provider is usable with the current setup.

    Ollama counts as enabled when a connection is obtained and it lists at least
    one supported model; any exception while checking means "not enabled".
    Any other provider counts as enabled only when it has an entry in
    `provider_warnings` and every required config key has a non-None value.
    """
    if provider_name == ModelProviderName.ollama:
        try:
            connection = await get_ollama_connection()
            return connection is not None and bool(connection.models)
        except Exception:
            return False

    warning = provider_warnings.get(provider_name)
    if warning is None:
        return False
    return all(
        get_config_value(config_key) is not None
        for config_key in warning.required_config_keys
    )
def get_config_value(key: str):
    """
    Reads a setting from the shared Config.

    The lookup calls `Config.__getattr__` directly, as the original code does.

    Args:
        key: Name of the config setting.

    Returns:
        The setting's value, or None if the lookup raises AttributeError.
    """
    try:
        value = Config.shared().__getattr__(key)
    except AttributeError:
        value = None
    return value
def check_provider_warnings(provider_name: ModelProviderName):
    """
    Validates that required configuration is present for a given provider.

    Providers without an entry in `provider_warnings` are not checked.

    Args:
        provider_name: The provider to check

    Raises:
        ValueError: With the provider's warning message, if any required config
            key resolves to None.
    """
    warning = provider_warnings.get(provider_name)
    if warning is None:
        return

    any_missing = any(
        get_config_value(config_key) is None
        for config_key in warning.required_config_keys
    )
    if any_missing:
        raise ValueError(warning.message)

Validates that required configuration is present for a given provider.

Args: provider_name: The provider to check

Raises: ValueError: If required configuration keys are missing

async def langchain_model_from(
    name: str, provider_name: str | None = None
) -> BaseChatModel:
    """
    Creates a LangChain chat model instance for the specified model and provider.

    Args:
        name: The name of the model to instantiate (a `ModelName` value).
        provider_name: Optional specific provider to use. When omitted, the model's
            first listed provider is used.

    Returns:
        A configured LangChain chat model instance.

    Raises:
        ValueError: If the model/provider combination is invalid or misconfigured,
            required provider config is missing, or (for Ollama) the server
            cannot be reached or the model is not installed.
    """
    if name not in ModelName.__members__:
        raise ValueError(f"Invalid name: {name}")

    # Select the model from built_in_models using the name. The explicit None
    # default matters: a bare next() would raise StopIteration, which Python
    # converts to RuntimeError inside a coroutine instead of a ValueError.
    model = next((m for m in built_in_models if m.name == name), None)
    if model is None:
        raise ValueError(f"Model {name} not found")

    # If a provider is provided, select the provider from the model's provider_config
    provider: KilnModelProvider | None = None
    if not model.providers:
        raise ValueError(f"Model {name} has no providers")
    elif provider_name is None:
        # TODO: priority order
        provider = model.providers[0]
    else:
        provider = next(
            (p for p in model.providers if p.name == provider_name), None
        )
    if provider is None:
        raise ValueError(f"Provider {provider_name} not found for model {name}")

    # Raises ValueError with a provider-specific message if required keys are unset.
    check_provider_warnings(provider.name)

    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't allow passing these, so ugly hack to set env vars.
        # NOTE: this writes process-wide environment variables.
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.ollama:
        # Ollama model naming is pretty flexible. We try a few versions of the model name
        potential_model_names = []
        if "model" in provider.provider_options:
            potential_model_names.append(provider.provider_options["model"])
        if "model_aliases" in provider.provider_options:
            potential_model_names.extend(provider.provider_options["model_aliases"])

        # Get the list of models Ollama supports
        ollama_connection = await get_ollama_connection()
        if ollama_connection is None:
            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")

        for model_name in potential_model_names:
            if ollama_model_supported(ollama_connection, model_name):
                return ChatOllama(model=model_name, base_url=ollama_base_url())

        raise ValueError(f"Model {name} not installed on Ollama")
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://getkiln.ai/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")

Creates a LangChain chat model instance for the specified model and provider.

Args: name: The name of the model to instantiate provider_name: Optional specific provider to use (defaults to first available)

Returns: A configured LangChain chat model instance

Raises: ValueError: If the model/provider combination is invalid or misconfigured

def ollama_base_url() -> str:
    """
    Gets the base URL for Ollama API connections.

    Returns:
        The OLLAMA_BASE_URL environment variable when it is set to a non-empty
        value, with trailing "/" removed so callers can append paths like
        "/api/tags". Otherwise the localhost default "http://localhost:11434".
    """
    # An empty variable is treated like an unset one, matching how
    # OPENROUTER_BASE_URL is read in this module (`getenv(...) or default`).
    env_base_url = (os.getenv("OLLAMA_BASE_URL") or "").rstrip("/")
    return env_base_url or "http://localhost:11434"

Gets the base URL for Ollama API connections.

Returns: The base URL to use for Ollama API calls, using environment variable if set or falling back to localhost default

async def ollama_online() -> bool:
    """
    Checks if the Ollama service is available and responding.

    Any HTTP response from the tags endpoint counts as online; the status code
    is not inspected.

    Returns:
        True if Ollama answered the request, False if the request failed with an
        httpx.RequestError (connection errors, timeouts, ...).
    """
    import asyncio

    try:
        # httpx.get is blocking; run it in a worker thread so this coroutine does
        # not stall the event loop while waiting on the network.
        await asyncio.to_thread(httpx.get, ollama_base_url() + "/api/tags")
    except httpx.RequestError:
        return False
    return True

Checks if the Ollama service is available and responding.

Returns: True if Ollama is available and responding, False otherwise

class OllamaConnection(BaseModel):
    """
    Result of probing an Ollama server for installed, supported models.

    Attributes:
        message: Human-readable status text, e.g. "Ollama connected".
        models: Installed Ollama model names that match a supported built-in
            model; empty when none are installed.
    """

    message: str
    models: List[str]

Usage docs: https://docs.pydantic.dev/2.9/concepts/models/

A base class for creating Pydantic models.

Attributes: __class_vars__: The names of the class variables defined on the model. __private_attributes__: Metadata about the private attributes of the model. __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.

__pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__: The core schema of the model.
__pydantic_custom_init__: Whether the model has a custom `__init__` function.
__pydantic_decorators__: Metadata containing the decorators defined on the model.
    This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
__pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to
    __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__: The name of the post-init method for the model, if defined.
__pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
__pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
__pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.

__pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra]
    is set to `'allow'`.
__pydantic_fields_set__: The names of fields explicitly set during instantiation.
__pydantic_private__: Values of private attributes set on the model instance.
message: str
models: List[str]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

model_fields: ClassVar[Dict[str, pydantic.fields.FieldInfo]] = {'message': FieldInfo(annotation=str, required=True), 'models': FieldInfo(annotation=List[str], required=True)}

Metadata about the fields defined on the model, mapping of field names to [FieldInfo][pydantic.fields.FieldInfo] objects.

This replaces Model.__fields__ from Pydantic V1.

model_computed_fields: ClassVar[Dict[str, pydantic.fields.ComputedFieldInfo]] = {}

A dictionary of computed field names and their corresponding ComputedFieldInfo objects.

def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
    """
    Builds an OllamaConnection from the parsed body of Ollama's /api/tags endpoint.

    Installed models are matched against the Ollama providers in `built_in_models`,
    using each provider's "model" option and any "model_aliases". A bare name
    also matches its ":latest" tag.

    Args:
        tags: Parsed JSON from /api/tags. Expected to be a dict with a "models"
            list whose entries carry a "model" name. Any other shape, and any
            malformed entry, is treated as "no supported model".

    Returns:
        A connection listing the installed supported models, in the order Ollama
        reported them. If none are installed, the connection carries an install
        hint and an empty model list. Never returns None; the Optional return
        type is kept for compatibility.
    """
    # Names of every built-in model served through Ollama, plus their aliases.
    supported_names: set[str] = set()
    for model in built_in_models:
        for provider in model.providers:
            if provider.name != ModelProviderName.ollama:
                continue
            supported_names.add(provider.provider_options["model"])
            supported_names.update(provider.provider_options.get("model_aliases", []))
    # Build the ":latest" variants once, rather than once per installed model.
    accepted_names = supported_names | {f"{name}:latest" for name in supported_names}

    models = tags.get("models") if isinstance(tags, dict) else None
    if isinstance(models, list):
        installed_names = [
            entry["model"]
            for entry in models
            if isinstance(entry, dict) and isinstance(entry.get("model"), str)
        ]
        available_supported_models = [
            installed for installed in installed_names if installed in accepted_names
        ]
        if available_supported_models:
            return OllamaConnection(
                message="Ollama connected",
                models=available_supported_models,
            )

    return OllamaConnection(
        message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
        models=[],
    )
async def get_ollama_connection() -> OllamaConnection | None:
    """
    Gets the connection status for Ollama.

    Returns:
        An OllamaConnection built from the /api/tags response, or None if the
        request failed or its body could not be parsed as JSON.
    """
    import asyncio

    try:
        # requests is blocking; run it in a worker thread so the event loop stays
        # responsive while waiting (bounded by the 5 second connect/read timeout).
        response = await asyncio.to_thread(
            requests.get, ollama_base_url() + "/api/tags", timeout=5
        )
        tags = response.json()
    except Exception:
        return None

    return parse_ollama_tags(tags)

Gets the connection status for Ollama.

def ollama_model_supported(conn: OllamaConnection, model_name: str) -> bool:
    """
    Reports whether `model_name` is among the supported models found on Ollama.

    The name matches either exactly or in its ":latest"-tagged form.
    """
    candidate_names = (model_name, f"{model_name}:latest")
    return any(candidate in conn.models for candidate in candidate_names)