kiln_ai.adapters.fine_tune.base_finetune

from abc import ABC, abstractmethod
from typing import Literal

from pydantic import BaseModel

from kiln_ai.adapters.ml_model_list import built_in_models
from kiln_ai.datamodel import (
    DatasetSplit,
    FinetuneDataStrategy,
    FineTuneStatusType,
    Task,
)
from kiln_ai.datamodel import Finetune as FinetuneModel
from kiln_ai.utils.name_generator import generate_memorable_name


class FineTuneStatus(BaseModel):
    """
    The status of a fine-tune, including a user friendly message.
    """

    status: FineTuneStatusType
    message: str | None = None


class FineTuneParameter(BaseModel):
    """
    A parameter for a fine-tune. Hyperparameters, etc.
    """

    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


TYPE_MAP = {
    "string": str,
    "int": int,
    "float": float,
    "bool": bool,
}


class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        self.datamodel = datamodel

    @classmethod
    async def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        thinking_instructions: str | None,
        data_strategy: FinetuneDataStrategy,
        parameters: dict[str, str | int | float | bool] = {},
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.
        """

        cls.check_valid_provider_model(provider_id, provider_base_model_id)

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=cls.augment_system_message(system_message, parent_task),
            thinking_instructions=thinking_instructions,
            parent=parent_task,
            data_strategy=data_strategy,
        )

        adapter = cls(datamodel)
        await adapter._start(dataset)

        datamodel.save_to_file()

        return adapter, datamodel

    @classmethod
    def augment_system_message(cls, system_message: str, task: Task) -> str:
        """
        Augment the system message with additional instructions, such as JSON instructions.
        """

        # Base implementation does nothing, can be overridden by subclasses
        return system_message

    @abstractmethod
    async def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    async def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.
        """
        # Check required parameters and parameter types
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if not parameter.optional and parameter.name not in parameters:
                raise ValueError(f"Parameter {parameter.name} is required")
            elif parameter.name in parameters:
                # check parameter is correct type
                expected_type = TYPE_MAP[parameter.type]
                value = parameters[parameter.name]

                # Strict type checking for numeric types
                if expected_type is float and not isinstance(value, float):
                    raise ValueError(
                        f"Parameter {parameter.name} must be a float, got {type(value)}"
                    )
                elif expected_type is int and not isinstance(value, int):
                    raise ValueError(
                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
                    )
                elif not isinstance(value, expected_type):
                    raise ValueError(
                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                    )

        allowed_parameters = [p.name for p in available_parameters]
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")

    @classmethod
    def check_valid_provider_model(
        cls, provider_id: str, provider_base_model_id: str
    ) -> None:
        """
        Check if the provider and base model are valid.
        """
        for model in built_in_models:
            for provider in model.providers:
                if (
                    provider.name == provider_id
                    and provider.provider_finetune_id == provider_base_model_id
                ):
                    return
        raise ValueError(
            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
        )
class FineTuneStatus(pydantic.main.BaseModel):

The status of a fine-tune, including a user friendly message.

status: FineTuneStatusType
message: str | None
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
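
For illustration, a provider-specific adapter's status() implementation typically maps the provider's job state onto one of these objects. A minimal sketch, assuming FineTuneStatusType exposes members such as completed, failed and running (the member names are assumptions, not confirmed by this module):

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatus
from kiln_ai.datamodel import FineTuneStatusType


def status_from_job_state(job_state: str) -> FineTuneStatus:
    # Hypothetical mapping from a provider's job-state string; the
    # FineTuneStatusType member names used here are assumptions.
    if job_state == "succeeded":
        return FineTuneStatus(status=FineTuneStatusType.completed, message="Fine-tune finished")
    if job_state in ("failed", "cancelled"):
        return FineTuneStatus(status=FineTuneStatusType.failed, message=f"Job ended as {job_state}")
    return FineTuneStatus(status=FineTuneStatusType.running, message=None)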

class FineTuneParameter(pydantic.main.BaseModel):

A parameter for a fine-tune. Hyperparameters, etc.

name: str
type: Literal['string', 'int', 'float', 'bool']
description: str
optional: bool
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
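
As a sketch of how these objects are used, a concrete adapter's available_parameters() returns a list of FineTuneParameter entries describing the hyperparameters it accepts. The parameter names and descriptions below are illustrative only, not parameters defined by this module:

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter

# Illustrative hyperparameter declarations for a hypothetical adapter.
EXAMPLE_PARAMETERS = [
    FineTuneParameter(
        name="epochs",
        type="int",
        description="Number of training epochs",
        optional=True,
    ),
    FineTuneParameter(
        name="learning_rate",
        type="float",
        description="Learning rate used by the provider's trainer",
        optional=True,
    ),
]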

TYPE_MAP = {'string': <class 'str'>, 'int': <class 'int'>, 'float': <class 'float'>, 'bool': <class 'bool'>}
class BaseFinetuneAdapter(abc.ABC):

A base class for fine-tuning adapters.

datamodel
@classmethod
async def create_and_start( cls, dataset: kiln_ai.datamodel.DatasetSplit, provider_id: str, provider_base_model_id: str, train_split_name: str, system_message: str, thinking_instructions: str | None, data_strategy: kiln_ai.datamodel.FinetuneDataStrategy, parameters: dict[str, str | int | float | bool] = {}, name: str | None = None, description: str | None = None, validation_split_name: str | None = None) -> tuple[BaseFinetuneAdapter, kiln_ai.datamodel.Finetune]:

Create and start a fine-tune.
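
A minimal usage sketch, assuming MyProviderFinetune is a concrete BaseFinetuneAdapter subclass, dataset is an existing DatasetSplit whose parent task is saved to disk, and the provider id, model id, split names and FinetuneDataStrategy member are placeholders rather than values confirmed by this module:

from kiln_ai.datamodel import FinetuneDataStrategy


async def launch_finetune(dataset):
    # All ids, split names and the data strategy member below are illustrative.
    adapter, finetune = await MyProviderFinetune.create_and_start(
        dataset=dataset,
        provider_id="openai",
        provider_base_model_id="gpt-4o-mini-2024-07-18",
        train_split_name="train",
        system_message="You are a helpful assistant.",
        thinking_instructions=None,
        data_strategy=FinetuneDataStrategy.final_only,
        parameters={"epochs": 3},
        validation_split_name="val",
    )
    # The returned Finetune datamodel has already been saved under the
    # dataset's parent task.
    return adapter, finetune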

@classmethod
def augment_system_message(cls, system_message: str, task: kiln_ai.datamodel.Task) -> str:

Augment the system message with additional instructions, such as JSON instructions.
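
A sketch of an override that appends output-format guidance to the system message; the output_json_schema attribute check is an assumption about the Task model, used only for illustration, and the fragment omits the abstract _start() and status() methods a real subclass must implement:

class JsonAwareFinetuneAdapter(BaseFinetuneAdapter):
    @classmethod
    def augment_system_message(cls, system_message: str, task: Task) -> str:
        # `output_json_schema` is an assumed Task field, shown for illustration.
        if getattr(task, "output_json_schema", None):
            return system_message + "\n\nAlways respond with valid JSON."
        return system_message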

@abstractmethod
async def status(self) -> FineTuneStatus:

Get the status of the fine-tune.
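
For example, a caller might poll status() until the job reaches a terminal state. A hedged sketch; the terminal FineTuneStatusType members named here are assumptions:

import asyncio


async def wait_for_completion(adapter: BaseFinetuneAdapter, interval_s: float = 30.0) -> FineTuneStatus:
    # Poll until a terminal state; the member names checked are assumptions.
    while True:
        current = await adapter.status()
        if current.status in (FineTuneStatusType.completed, FineTuneStatusType.failed):
            return current
        await asyncio.sleep(interval_s)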

@classmethod
def available_parameters(cls) -> list[FineTuneParameter]:

Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.

@classmethod
def validate_parameters(cls, parameters: dict[str, str | int | float | bool]) -> None:

Validate the parameters for this fine-tune.
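
Assuming a hypothetical subclass MyProviderFinetune whose available_parameters() declares an optional int "epochs" and an optional float "learning_rate", validation behaves roughly as follows:

MyProviderFinetune.validate_parameters({"epochs": 3})             # passes
MyProviderFinetune.validate_parameters({"learning_rate": 1e-4})   # passes

try:
    MyProviderFinetune.validate_parameters({"epochs": 3.5})       # int expected, float given
except ValueError as exc:
    print(exc)  # Parameter epochs must be an integer, got <class 'float'>

try:
    MyProviderFinetune.validate_parameters({"batch_size": 8})     # not declared by the adapter
except ValueError as exc:
    print(exc)  # Parameter batch_size is not available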

@classmethod
def check_valid_provider_model(cls, provider_id: str, provider_base_model_id: str) -> None:

Check if the provider and base model are valid.
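
Putting the pieces together, a provider-specific adapter subclasses BaseFinetuneAdapter, implements _start() and status(), and usually overrides available_parameters(); create_and_start() then checks the provider/model pair, validates parameters, builds the Finetune datamodel and delegates to _start(). A minimal skeleton with every provider call left as a placeholder; the class, its parameter, and the FineTuneStatusType member used are assumptions for illustration:

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneParameter,
    FineTuneStatus,
)
from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType


class MyProviderFinetune(BaseFinetuneAdapter):
    # Hypothetical adapter; the provider client calls are placeholders.

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(name="epochs", type="int", description="Training epochs"),
        ]

    async def _start(self, dataset: DatasetSplit) -> None:
        # Upload the training split and create the provider job here, then
        # record the provider's job id on self.datamodel as appropriate.
        ...

    async def status(self) -> FineTuneStatus:
        # Query the provider job and translate its state; the
        # FineTuneStatusType member used here is an assumption.
        return FineTuneStatus(status=FineTuneStatusType.unknown, message="Not implemented")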