kiln_ai.adapters.langchain_adapters

from typing import Dict

from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables import Runnable
from pydantic import BaseModel

import kiln_ai.datamodel as datamodel

from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
from .ml_model_list import langchain_model_from

LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]


class LangChainPromptAdapter(BaseAdapter):
    _model: LangChainModelType | None = None

    def __init__(
        self,
        kiln_task: datamodel.Task,
        custom_model: BaseChatModel | None = None,
        model_name: str | None = None,
        provider: str | None = None,
        prompt_builder: BasePromptBuilder | None = None,
    ):
        super().__init__(kiln_task, prompt_builder=prompt_builder)
        if custom_model is not None:
            self._model = custom_model

            # Attempt to infer the model provider and name from the custom model
            self.model_provider = "custom.langchain:" + custom_model.__class__.__name__
            self.model_name = "custom.langchain:unknown_model"
            if hasattr(custom_model, "model_name") and isinstance(
                getattr(custom_model, "model_name"), str
            ):
                self.model_name = "custom.langchain:" + getattr(
                    custom_model, "model_name"
                )
            if hasattr(custom_model, "model") and isinstance(
                getattr(custom_model, "model"), str
            ):
                self.model_name = "custom.langchain:" + getattr(custom_model, "model")
        elif model_name is not None:
            self.model_name = model_name
            self.model_provider = provider or "custom.langchain.default_provider"
        else:
            raise ValueError(
                "model_name and provider must be provided if custom_model is not provided"
            )

    def adapter_specific_instructions(self) -> str | None:
        if self.has_structured_output():
            return "Always respond with a tool call. Never respond with a human readable message."
        return None

    async def model(self) -> LangChainModelType:
        # Return the cached model if one was already resolved
        if self._model:
            return self._model

        self._model = await langchain_model_from(self.model_name, self.model_provider)

        if self.has_structured_output():
            if not hasattr(self._model, "with_structured_output") or not callable(
                getattr(self._model, "with_structured_output")
            ):
                raise ValueError(
                    f"model {self._model} does not support structured output, cannot use output_json_schema"
                )
            # LangChain expects title/description at the top level, in addition to the JSON schema body
            output_schema = self.kiln_task.output_schema()
            if output_schema is None:
                raise ValueError(
                    f"output_json_schema is not valid json: {self.kiln_task.output_json_schema}"
                )
            output_schema["title"] = "task_response"
            output_schema["description"] = "A response from the task"
            self._model = self._model.with_structured_output(
                output_schema, include_raw=True
            )
        return self._model

    async def _run(self, input: Dict | str) -> RunOutput:
        model = await self.model()
        chain = model
        intermediate_outputs = {}

        prompt = self.build_prompt()
        user_msg = self.prompt_builder.build_user_message(input)
        messages = [
            SystemMessage(content=prompt),
            HumanMessage(content=user_msg),
        ]

        # Chain of thought (COT) with structured output
        cot_prompt = self.prompt_builder.chain_of_thought_prompt()
        if cot_prompt and self.has_structured_output():
            # Base model (without structured output) used for the COT message
            base_model = await langchain_model_from(
                self.model_name, self.model_provider
            )
            messages.append(
                SystemMessage(content=cot_prompt),
            )

            cot_messages = [*messages]
            cot_response = base_model.invoke(cot_messages)
            intermediate_outputs["chain_of_thought"] = cot_response.content
            messages.append(AIMessage(content=cot_response.content))
            messages.append(
                SystemMessage(content="Considering the above, return a final result.")
            )
        elif cot_prompt:
            # For plaintext output, we just add the COT instructions; we still make only one call.
            messages.append(SystemMessage(content=cot_prompt))

        response = chain.invoke(messages)

        if self.has_structured_output():
            if (
                not isinstance(response, dict)
                or "parsed" not in response
                or not isinstance(response["parsed"], dict)
            ):
                raise RuntimeError(f"structured response not returned: {response}")
            structured_response = response["parsed"]
            return RunOutput(
                output=self._munge_response(structured_response),
                intermediate_outputs=intermediate_outputs,
            )
        else:
            if not isinstance(response, BaseMessage):
                raise RuntimeError(f"response is not a BaseMessage: {response}")
            text_content = response.content
            if not isinstance(text_content, str):
                raise RuntimeError(f"response is not a string: {text_content}")
            return RunOutput(
                output=text_content,
                intermediate_outputs=intermediate_outputs,
            )

    def adapter_info(self) -> AdapterInfo:
        return AdapterInfo(
            model_name=self.model_name,
            model_provider=self.model_provider,
            adapter_name="kiln_langchain_adapter",
            prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
        )

    def _munge_response(self, response: Dict) -> Dict:
        # Mistral Large's tool calling format differs slightly; convert it to the standard format.
        if (
            "name" in response
            and response["name"] == "task_response"
            and "arguments" in response
        ):
            return response["arguments"]
        return response
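A minimal usage sketch covering both construction paths. The task fields and the model_name/provider values below are illustrative, not canonical:

import asyncio

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter


async def main() -> None:
    # Illustrative task; a real Task may require additional fields.
    task = datamodel.Task(name="summarize", instruction="Summarize the input text.")

    # Path 1: resolve the model from a name/provider pair via langchain_model_from.
    adapter = LangChainPromptAdapter(task, model_name="llama_3_1_8b", provider="ollama")

    # Path 2: pass any LangChain BaseChatModel directly; names are then inferred.
    # from langchain_ollama import ChatOllama
    # adapter = LangChainPromptAdapter(task, custom_model=ChatOllama(model="llama3.1"))

    model = await adapter.model()  # resolves and caches the LangChain model
    print(adapter.adapter_info())


asyncio.run(main())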
LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
class LangChainPromptAdapter(kiln_ai.adapters.base_adapter.BaseAdapter):

Base class for AI model adapters that handle task execution.

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

Attributes:
    prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
    kiln_task (Task): The task configuration and metadata
    output_schema (dict | None): JSON schema for validating structured outputs
    input_schema (dict | None): JSON schema for validating structured inputs

LangChainPromptAdapter(
    kiln_task: kiln_ai.datamodel.Task,
    custom_model: langchain_core.language_models.chat_models.BaseChatModel | None = None,
    model_name: str | None = None,
    provider: str | None = None,
    prompt_builder: kiln_ai.adapters.prompt_builders.BasePromptBuilder | None = None,
)
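When custom_model is provided, the provider is inferred from the model's class name and the model name from its model_name or model attribute. A sketch of the resulting identifiers, using langchain_ollama.ChatOllama as an arbitrary BaseChatModel and a previously constructed task:

from langchain_ollama import ChatOllama

adapter = LangChainPromptAdapter(task, custom_model=ChatOllama(model="llama3.1"))

# Provider comes from the class name; the name from the `model` attribute.
assert adapter.model_provider == "custom.langchain:ChatOllama"
assert adapter.model_name == "custom.langchain:llama3.1"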
def adapter_specific_instructions(self) -> str | None:
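The directive is only returned when the task defines structured output; a sketch, assuming adapter was built for such a task:

if adapter.has_structured_output():
    instructions = adapter.adapter_specific_instructions()
    # "Always respond with a tool call. Never respond with a human readable message."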
async def model(self) -> LangChainModelType:
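The structured-output branch amounts to the transformation sketched below: a top-level title/description is injected into the task's JSON schema, and the model is wrapped with with_structured_output(..., include_raw=True), so invocations return a dict containing a parsed key. The schema and base_chat_model here are hypothetical stand-ins:

# `base_chat_model` stands in for any LangChain BaseChatModel resolved earlier.
output_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

# Kiln injects a top-level title/description before binding the schema.
output_schema["title"] = "task_response"
output_schema["description"] = "A response from the task"

structured_model = base_chat_model.with_structured_output(
    output_schema, include_raw=True
)
# With include_raw=True, invoke() returns a dict:
# {"raw": AIMessage(...), "parsed": {"answer": ...}, "parsing_error": None}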
def adapter_info(self) -> kiln_ai.adapters.base_adapter.AdapterInfo: