kiln_ai.adapters.langchain_adapters

import os
from os import getenv
from typing import Any, Dict

from langchain_aws import ChatBedrockConverse
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables import Runnable
from langchain_fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ollama_tools import (
    get_ollama_connection,
    ollama_base_url,
    ollama_model_installed,
)
from kiln_ai.utils.config import Config

from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
from .ml_model_list import KilnModelProvider, ModelProviderName
from .provider_tools import kiln_model_provider_from

LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]


class LangchainAdapter(BaseAdapter):
    _model: LangChainModelType | None = None

    def __init__(
        self,
        kiln_task: datamodel.Task,
        custom_model: BaseChatModel | None = None,
        model_name: str | None = None,
        provider: str | None = None,
        prompt_builder: BasePromptBuilder | None = None,
        tags: list[str] | None = None,
    ):
        super().__init__(kiln_task, prompt_builder=prompt_builder, tags=tags)
        if custom_model is not None:
            self._model = custom_model

            # Attempt to infer model provider and name from custom model
            self.model_provider = "custom.langchain:" + custom_model.__class__.__name__
            self.model_name = "custom.langchain:unknown_model"
            if hasattr(custom_model, "model_name") and isinstance(
                getattr(custom_model, "model_name"), str
            ):
                self.model_name = "custom.langchain:" + getattr(
                    custom_model, "model_name"
                )
            if hasattr(custom_model, "model") and isinstance(
                getattr(custom_model, "model"), str
            ):
                self.model_name = "custom.langchain:" + getattr(custom_model, "model")
        elif model_name is not None:
            self.model_name = model_name
            self.model_provider = provider or "custom.langchain.default_provider"
        else:
            raise ValueError(
                "model_name and provider must be provided if custom_model is not provided"
            )

    async def model(self) -> LangChainModelType:
        # cached model
        if self._model:
            return self._model

        self._model = await langchain_model_from(self.model_name, self.model_provider)

        if self.has_structured_output():
            if not hasattr(self._model, "with_structured_output") or not callable(
                getattr(self._model, "with_structured_output")
            ):
                raise ValueError(
                    f"model {self._model} does not support structured output, cannot use output_json_schema"
                )
            # Langchain expects title/description to be at top level, on top of json schema
            output_schema = self.kiln_task.output_schema()
            if output_schema is None:
                raise ValueError(
                    f"output_json_schema is not valid json: {self.kiln_task.output_json_schema}"
                )
            output_schema["title"] = "task_response"
            output_schema["description"] = "A response from the task"
            with_structured_output_options = await get_structured_output_options(
                self.model_name, self.model_provider
            )
            self._model = self._model.with_structured_output(
                output_schema,
                include_raw=True,
                **with_structured_output_options,
            )
        return self._model

    async def _run(self, input: Dict | str) -> RunOutput:
        model = await self.model()
        chain = model
        intermediate_outputs = {}

        prompt = self.build_prompt()
        user_msg = self.prompt_builder.build_user_message(input)
        messages = [
            SystemMessage(content=prompt),
            HumanMessage(content=user_msg),
        ]

        # COT with structured output
        cot_prompt = self.prompt_builder.chain_of_thought_prompt()
        if cot_prompt and self.has_structured_output():
            # Base model (without structured output) used for COT message
            base_model = await langchain_model_from(
                self.model_name, self.model_provider
            )
            messages.append(
                SystemMessage(content=cot_prompt),
            )

            cot_messages = [*messages]
            cot_response = await base_model.ainvoke(cot_messages)
            intermediate_outputs["chain_of_thought"] = cot_response.content
            messages.append(AIMessage(content=cot_response.content))
            messages.append(
                SystemMessage(content="Considering the above, return a final result.")
            )
        elif cot_prompt:
            messages.append(SystemMessage(content=cot_prompt))

        response = await chain.ainvoke(messages)

        if self.has_structured_output():
            if (
                not isinstance(response, dict)
                or "parsed" not in response
                or not isinstance(response["parsed"], dict)
            ):
                raise RuntimeError(f"structured response not returned: {response}")
            structured_response = response["parsed"]
            return RunOutput(
                output=self._munge_response(structured_response),
                intermediate_outputs=intermediate_outputs,
            )
        else:
            if not isinstance(response, BaseMessage):
                raise RuntimeError(f"response is not a BaseMessage: {response}")
            text_content = response.content
            if not isinstance(text_content, str):
                raise RuntimeError(f"response is not a string: {text_content}")
            return RunOutput(
                output=text_content,
                intermediate_outputs=intermediate_outputs,
            )

    def adapter_info(self) -> AdapterInfo:
        return AdapterInfo(
            model_name=self.model_name,
            model_provider=self.model_provider,
            adapter_name="kiln_langchain_adapter",
            prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
            prompt_id=self.prompt_builder.prompt_id(),
        )

    def _munge_response(self, response: Dict) -> Dict:
        # Mistral Large tool calling format is a bit different. Convert to standard format.
        if (
            "name" in response
            and response["name"] == "task_response"
            and "arguments" in response
        ):
            return response["arguments"]
        return response


async def get_structured_output_options(
    model_name: str, model_provider: str
) -> Dict[str, Any]:
    finetune_provider = await kiln_model_provider_from(model_name, model_provider)
    if finetune_provider and finetune_provider.adapter_options.get("langchain"):
        return finetune_provider.adapter_options["langchain"].get(
            "with_structured_output_options", {}
        )
    return {}


async def langchain_model_from(
    name: str, provider_name: str | None = None
) -> BaseChatModel:
    provider = await kiln_model_provider_from(name, provider_name)
    return await langchain_model_from_provider(provider, name)


async def langchain_model_from_provider(
    provider: KilnModelProvider, model_name: str
) -> BaseChatModel:
    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.openai_compatible:
        # See provider_tools.py for how base_url, key and other parameters are set
        return ChatOpenAI(**provider.provider_options)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't allow passing these, so ugly hack to set env vars
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.fireworks_ai:
        api_key = Config.shared().fireworks_api_key
        return ChatFireworks(**provider.provider_options, api_key=api_key)
    elif provider.name == ModelProviderName.ollama:
        # Ollama model naming is pretty flexible. We try a few versions of the model name
        potential_model_names = []
        if "model" in provider.provider_options:
            potential_model_names.append(provider.provider_options["model"])
        if "model_aliases" in provider.provider_options:
            potential_model_names.extend(provider.provider_options["model_aliases"])

        # Get the list of models Ollama supports
        ollama_connection = await get_ollama_connection()
        if ollama_connection is None:
            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")

        for model_name in potential_model_names:
            if ollama_model_installed(ollama_connection, model_name):
                return ChatOllama(model=model_name, base_url=ollama_base_url())

        raise ValueError(f"Model {model_name} not installed on Ollama")
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://getkiln.ai/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}")
LangChainModelType = typing.Union[langchain_core.language_models.chat_models.BaseChatModel, langchain_core.runnables.base.Runnable[typing.Union[langchain_core.prompt_values.PromptValue, str, collections.abc.Sequence[typing.Union[langchain_core.messages.base.BaseMessage, list[str], tuple[str, str], str, dict[str, typing.Any]]]], typing.Union[typing.Dict, pydantic.main.BaseModel]]]
class LangchainAdapter(kiln_ai.adapters.base_adapter.BaseAdapter):

Base class for AI model adapters that handle task execution. (Docstring inherited from BaseAdapter.)

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

Attributes:
    prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
    kiln_task (Task): The task configuration and metadata
    output_schema (dict | None): JSON schema for validating structured outputs
    input_schema (dict | None): JSON schema for validating structured inputs

LangchainAdapter(
    kiln_task: kiln_ai.datamodel.Task,
    custom_model: langchain_core.language_models.chat_models.BaseChatModel | None = None,
    model_name: str | None = None,
    provider: str | None = None,
    prompt_builder: kiln_ai.adapters.prompt_builders.BasePromptBuilder | None = None,
    tags: list[str] | None = None,
)
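Construction sketch (hedged): the adapter arguments below come from the signature above, but the datamodel.Task fields and the model/provider identifiers are placeholders for illustration, not values guaranteed by this module.

    import kiln_ai.datamodel as datamodel
    from kiln_ai.adapters.langchain_adapters import LangchainAdapter

    # Hypothetical task; field names on datamodel.Task are assumptions,
    # not defined in this module.
    task = datamodel.Task(
        name="summarize",
        instruction="Summarize the input text in one sentence.",
    )

    # Option 1: let Kiln resolve the model from a name/provider pair
    # ("llama_3_1_8b" / "ollama" are placeholders).
    adapter = LangchainAdapter(task, model_name="llama_3_1_8b", provider="ollama")

    # Option 2: bring your own LangChain chat model.
    from langchain_openai import ChatOpenAI
    adapter = LangchainAdapter(task, custom_model=ChatOpenAI(model="gpt-4o-mini"))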
async def model(self) -> LangChainModelType:
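When the task has structured output, model() wraps the chat model with LangChain's with_structured_output(..., include_raw=True), so ainvoke returns a dict rather than a message. A rough sketch of the shapes involved; the example schema is illustrative and the result keys follow LangChain's include_raw=True convention:

    # Illustrative task schema: one required string field.
    output_schema = {
        "type": "object",
        "properties": {"summary": {"type": "string"}},
        "required": ["summary"],
    }
    # model() adds the top-level title/description LangChain expects:
    output_schema["title"] = "task_response"
    output_schema["description"] = "A response from the task"

    # With include_raw=True the wrapped model's ainvoke() returns roughly:
    # {
    #     "raw": AIMessage(...),          # unparsed model message
    #     "parsed": {"summary": "..."},   # what _run reads via response["parsed"]
    #     "parsing_error": None,
    # }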
def adapter_info(self) -> kiln_ai.adapters.base_adapter.AdapterInfo:
async def get_structured_output_options(model_name: str, model_provider: str) -> Dict[str, Any]:
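The options come from KilnModelProvider.adapter_options, defined in ml_model_list.py. A sketch of the nesting this helper reads; the "method" value is only an example of a with_structured_output keyword argument, not a value taken from the model list:

    # Hypothetical provider entry; only the nesting matters here.
    adapter_options = {
        "langchain": {
            "with_structured_output_options": {"method": "json_mode"},
        }
    }
    # For such a provider, get_structured_output_options returns
    # {"method": "json_mode"}, which model() forwards as **kwargs to
    # with_structured_output().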
async def langchain_model_from(name: str, provider_name: str | None = None) -> langchain_core.language_models.chat_models.BaseChatModel:
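Typical use is a single awaited call inside an async function; the model name and provider below are placeholders, not guaranteed entries in the Kiln model list:

    from kiln_ai.adapters.langchain_adapters import langchain_model_from

    # Inside an async function. "llama_3_1_8b" / "ollama" are placeholder identifiers.
    model = await langchain_model_from("llama_3_1_8b", "ollama")
    reply = await model.ainvoke("Say hello")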
async def langchain_model_from_provider(provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, model_name: str) -> langchain_core.language_models.chat_models.BaseChatModel:
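A direct-call sketch, assuming a KilnModelProvider can be constructed with just the name and provider_options fields this function reads; real instances normally come from kiln_model_provider_from:

    from kiln_ai.adapters.langchain_adapters import langchain_model_from_provider
    from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName

    provider = KilnModelProvider(
        name=ModelProviderName.ollama,
        provider_options={"model": "llama3.1"},  # placeholder Ollama tag
    )
    # Requires a running Ollama instance with the model pulled.
    chat_model = await langchain_model_from_provider(provider, "llama_3_1_8b")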