kiln_ai.adapters.fine_tune.base_finetune

from abc import ABC, abstractmethod
from enum import Enum
from typing import Literal

from pydantic import BaseModel

from kiln_ai.adapters.ml_model_list import built_in_models
from kiln_ai.datamodel import DatasetSplit
from kiln_ai.datamodel import Finetune as FinetuneModel
from kiln_ai.utils.name_generator import generate_memorable_name


class FineTuneStatusType(str, Enum):
    """
    The status type of a fine-tune (running, completed, failed, etc).
    """

    unknown = "unknown"  # server error
    pending = "pending"
    running = "running"
    completed = "completed"
    failed = "failed"


class FineTuneStatus(BaseModel):
    """
    The status of a fine-tune, including a user friendly message.
    """

    status: FineTuneStatusType
    message: str | None = None


class FineTuneParameter(BaseModel):
    """
    A parameter for a fine-tune. Hyperparameters, etc.
    """

    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


# Add type mapping
TYPE_MAP = {
    "string": str,
    "int": int,
    "float": float,
    "bool": bool,
}


class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        self.datamodel = datamodel

    @classmethod
    def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        parameters: dict[str, str | int | float | bool] = {},
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.
        """

        cls.check_valid_provider_model(provider_id, provider_base_model_id)

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=system_message,
            parent=parent_task,
        )

        adapter = cls(datamodel)
        adapter._start(dataset)

        datamodel.save_to_file()

        return adapter, datamodel

    @abstractmethod
    def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.
        """
        # Check required parameters and parameter types
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if not parameter.optional and parameter.name not in parameters:
                raise ValueError(f"Parameter {parameter.name} is required")
            elif parameter.name in parameters:
                # check parameter is correct type
                expected_type = TYPE_MAP[parameter.type]
                value = parameters[parameter.name]

                # Strict type checking for numeric types
                if expected_type is float and not isinstance(value, float):
                    raise ValueError(
                        f"Parameter {parameter.name} must be a float, got {type(value)}"
                    )
                elif expected_type is int and not isinstance(value, int):
                    raise ValueError(
                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
                    )
                elif not isinstance(value, expected_type):
                    raise ValueError(
                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                    )

        allowed_parameters = [p.name for p in available_parameters]
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")

    @classmethod
    def check_valid_provider_model(
        cls, provider_id: str, provider_base_model_id: str
    ) -> None:
        """
        Check if the provider and base model are valid.
        """
        for model in built_in_models:
            for provider in model.providers:
                if (
                    provider.name == provider_id
                    and provider.provider_finetune_id == provider_base_model_id
                ):
                    return
        raise ValueError(
            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
        )
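
A minimal sketch of how a concrete adapter plugs into this base class: subclass BaseFinetuneAdapter, implement _start and status, and let create_and_start handle validation and persistence. The ExampleFinetuneAdapter name, the provider and model ids, and the dataset setup are illustrative assumptions, not part of this module.

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneStatus,
    FineTuneStatusType,
)
from kiln_ai.datamodel import DatasetSplit


class ExampleFinetuneAdapter(BaseFinetuneAdapter):
    """Hypothetical adapter used only to illustrate the contract."""

    def _start(self, dataset: DatasetSplit) -> None:
        # A real adapter would upload the training split and submit a job here,
        # typically recording the provider's job id on self.datamodel.
        pass

    def status(self) -> FineTuneStatus:
        # A real adapter would poll the provider; this sketch reports a fixed state.
        return FineTuneStatus(
            status=FineTuneStatusType.running,
            message="Job submitted (illustrative placeholder)",
        )


# Assuming `dataset` is an existing DatasetSplit with a parent task and a
# "train" split, and that the provider/model pair exists in built_in_models:
# adapter, finetune_record = ExampleFinetuneAdapter.create_and_start(
#     dataset=dataset,
#     provider_id="some_provider",
#     provider_base_model_id="some_model_id",
#     train_split_name="train",
#     system_message="You are a helpful assistant.",
# )
# print(adapter.status())
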
class FineTuneStatusType(builtins.str, enum.Enum):

The status type of a fine-tune (running, completed, failed, etc).

unknown = <FineTuneStatusType.unknown: 'unknown'>
pending = <FineTuneStatusType.pending: 'pending'>
running = <FineTuneStatusType.running: 'running'>
completed = <FineTuneStatusType.completed: 'completed'>
failed = <FineTuneStatusType.failed: 'failed'>
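
The enum splits into in-progress states (pending, running) and terminal states (completed, failed), with unknown reserved for server errors. A hedged sketch of how a caller might poll any adapter until a terminal state is reached; the helper name and poll interval are assumptions, not part of this module.

import time

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneStatus,
    FineTuneStatusType,
)

# Terminal vs. in-progress states; "unknown" marks a server error.
TERMINAL_STATES = {FineTuneStatusType.completed, FineTuneStatusType.failed}


def wait_for_finetune(
    adapter: BaseFinetuneAdapter, poll_seconds: float = 30.0
) -> FineTuneStatus:
    """Illustrative helper: poll an adapter until its fine-tune finishes."""
    while True:
        current = adapter.status()
        if current.status in TERMINAL_STATES:
            return current
        time.sleep(poll_seconds)  # pending, running, or unknown: keep polling
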
class FineTuneStatus(pydantic.main.BaseModel):

The status of a fine-tune, including a user friendly message.

status: FineTuneStatusType
message: str | None

class FineTuneParameter(pydantic.main.BaseModel):

A parameter for a fine-tune. Hyperparameters, etc.

name: str
type: Literal['string', 'int', 'float', 'bool']
description: str
optional: bool
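
A short sketch of declaring hyperparameters with this model. The epochs and learning_rate names are illustrative assumptions, not parameters any particular provider is guaranteed to accept.

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter

# Illustrative hyperparameter declarations; names are assumptions.
EXAMPLE_PARAMETERS = [
    FineTuneParameter(
        name="epochs",
        type="int",
        description="Number of training epochs",
        optional=True,
    ),
    FineTuneParameter(
        name="learning_rate",
        type="float",
        description="Learning rate multiplier",
        optional=True,
    ),
]
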

TYPE_MAP = {'string': <class 'str'>, 'int': <class 'int'>, 'float': <class 'float'>, 'bool': <class 'bool'>}
class BaseFinetuneAdapter(abc.ABC):

A base class for fine-tuning adapters.

datamodel
@classmethod
def create_and_start( cls, dataset: kiln_ai.datamodel.DatasetSplit, provider_id: str, provider_base_model_id: str, train_split_name: str, system_message: str, parameters: dict[str, str | int | float | bool] = {}, name: str | None = None, description: str | None = None, validation_split_name: str | None = None) -> tuple[BaseFinetuneAdapter, kiln_ai.datamodel.Finetune]:

Create and start a fine-tune.
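
Before any job is submitted, create_and_start checks the provider/model pair against built_in_models, verifies that the named train and validation splits exist, validates parameters, and requires the dataset to have a parent task with a path; only then does it build the Finetune record, call _start, and save the record to disk. A hedged sketch of the call; ExampleFinetuneAdapter, `dataset`, and the ids are placeholders carried over from the sketch near the top of this page.

# Placeholders: ExampleFinetuneAdapter, `dataset`, and the ids are assumptions.
try:
    adapter, finetune_record = ExampleFinetuneAdapter.create_and_start(
        dataset=dataset,
        provider_id="some_provider",
        provider_base_model_id="some_model_id",
        train_split_name="train",
        validation_split_name="val",
        system_message="You are a helpful assistant.",
        parameters={"epochs": 3},
    )
    print(finetune_record.name)  # memorable name is auto-generated if none given
except ValueError as error:
    print(error)  # any violated precondition is reported here
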

@abstractmethod
def status(self) -> FineTuneStatus:

Get the status of the fine-tune.
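
Subclasses implement status() by mapping their provider's job state onto FineTuneStatusType. A hedged sketch of what a concrete implementation might look like; _fetch_provider_job and its state strings are invented for illustration and are not part of any Kiln adapter.

# Inside a hypothetical BaseFinetuneAdapter subclass:
def status(self) -> FineTuneStatus:
    job = self._fetch_provider_job()  # hypothetical helper returning a dict
    mapping = {
        "queued": FineTuneStatusType.pending,
        "training": FineTuneStatusType.running,
        "succeeded": FineTuneStatusType.completed,
        "error": FineTuneStatusType.failed,
    }
    state = mapping.get(job.get("state"), FineTuneStatusType.unknown)
    return FineTuneStatus(status=state, message=job.get("detail"))
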

@classmethod
def available_parameters(cls) -> list[FineTuneParameter]:

Returns a list of parameters that can be provided for this fine-tune.
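
The base implementation returns an empty list; concrete adapters override it to advertise the hyperparameters they accept, which validate_parameters then enforces. A hedged sketch reusing the illustrative epochs and learning_rate parameters from the FineTuneParameter example above.

# Inside a hypothetical BaseFinetuneAdapter subclass:
@classmethod
def available_parameters(cls) -> list[FineTuneParameter]:
    return [
        FineTuneParameter(
            name="epochs",
            type="int",
            description="Number of training epochs (illustrative)",
        ),
        FineTuneParameter(
            name="learning_rate",
            type="float",
            description="Learning rate multiplier (illustrative)",
        ),
    ]
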

@classmethod
def validate_parameters(cls, parameters: dict[str, str | int | float | bool]) -> None:

Validate the parameters for this fine-tune.
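
Validation is strict: required parameters must be present, numeric types must match exactly (an int is rejected where a float is declared, and vice versa), and names the adapter never declared are rejected. A hedged sketch of the behavior, assuming an adapter whose available_parameters() declares an optional float learning_rate as in the override sketch above.

# Assumes ExampleFinetuneAdapter.available_parameters() declares
# learning_rate as an optional "float" parameter (illustrative).
ExampleFinetuneAdapter.validate_parameters({"learning_rate": 3e-4})  # passes

try:
    ExampleFinetuneAdapter.validate_parameters({"learning_rate": 1})
except ValueError as error:
    print(error)  # Parameter learning_rate must be a float, got <class 'int'>

try:
    ExampleFinetuneAdapter.validate_parameters({"batch_size": 8})
except ValueError as error:
    print(error)  # Parameter batch_size is not available
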

@classmethod
def check_valid_provider_model(cls, provider_id: str, provider_base_model_id: str) -> None:

Check if the provider and base model are valid.
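
The check scans built_in_models for a provider whose name equals provider_id and whose provider_finetune_id equals the requested base model, raising ValueError when no pair matches. A short usage sketch; the ids below are placeholders, not a claim about what built_in_models actually contains.

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter

try:
    BaseFinetuneAdapter.check_valid_provider_model("some_provider", "some_model_id")
    print("Provider/model pair supports fine-tuning")
except ValueError as error:
    print(error)  # Provider ... with base model ... is not available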