# kiln_ai.adapters.fine_tune.base_finetune
Module source:

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Literal

from pydantic import BaseModel

from kiln_ai.adapters.ml_model_list import built_in_models
from kiln_ai.datamodel import DatasetSplit
from kiln_ai.datamodel import Finetune as FinetuneModel
from kiln_ai.utils.name_generator import generate_memorable_name


class FineTuneStatusType(str, Enum):
    """
    The status type of a fine-tune (running, completed, failed, etc).
    """

    unknown = "unknown"  # server error
    pending = "pending"
    running = "running"
    completed = "completed"
    failed = "failed"


class FineTuneStatus(BaseModel):
    """
    The status of a fine-tune, including a user friendly message.
    """

    status: FineTuneStatusType
    message: str | None = None


class FineTuneParameter(BaseModel):
    """
    A parameter for a fine-tune. Hyperparameters, etc.
    """

    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


# Add type mapping
TYPE_MAP = {
    "string": str,
    "int": int,
    "float": float,
    "bool": bool,
}


class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        self.datamodel = datamodel

    @classmethod
    def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        parameters: dict[str, str | int | float | bool] = {},
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.
        """

        cls.check_valid_provider_model(provider_id, provider_base_model_id)

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=system_message,
            parent=parent_task,
        )

        adapter = cls(datamodel)
        adapter._start(dataset)

        datamodel.save_to_file()

        return adapter, datamodel

    @abstractmethod
    def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.
        """
        # Check required parameters and parameter types
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if not parameter.optional and parameter.name not in parameters:
                raise ValueError(f"Parameter {parameter.name} is required")
            elif parameter.name in parameters:
                # check parameter is correct type
                expected_type = TYPE_MAP[parameter.type]
                value = parameters[parameter.name]

                # Strict type checking for numeric types
                if expected_type is float and not isinstance(value, float):
                    raise ValueError(
                        f"Parameter {parameter.name} must be a float, got {type(value)}"
                    )
                elif expected_type is int and not isinstance(value, int):
                    raise ValueError(
                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
                    )
                elif not isinstance(value, expected_type):
                    raise ValueError(
                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                    )

        allowed_parameters = [p.name for p in available_parameters]
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")

    @classmethod
    def check_valid_provider_model(
        cls, provider_id: str, provider_base_model_id: str
    ) -> None:
        """
        Check if the provider and base model are valid.
        """
        for model in built_in_models:
            for provider in model.providers:
                if (
                    provider.name == provider_id
                    and provider.provider_finetune_id == provider_base_model_id
                ):
                    return
        raise ValueError(
            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
        )
```
## FineTuneStatusType

The status type of a fine-tune: `pending`, `running`, `completed`, or `failed`; `unknown` is reserved for server errors.
## FineTuneStatus

The status of a fine-tune, including a user-friendly message.
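Because `FineTuneStatusType` subclasses `str`, members compare equal to their raw string values. A minimal sketch of constructing and branching on a status (the message text is illustrative):

```python
from kiln_ai.adapters.fine_tune.base_finetune import (
    FineTuneStatus,
    FineTuneStatusType,
)

# Construct a status as a provider adapter would after polling its API
status = FineTuneStatus(
    status=FineTuneStatusType.running,
    message="Training in progress",
)

# str-Enum members compare equal to their values
assert status.status == "running"

if status.status in (FineTuneStatusType.completed, FineTuneStatusType.failed):
    print(f"Terminal: {status.status.value} - {status.message}")
else:
    print(f"In progress: {status.status.value}")
```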
## FineTuneParameter

A parameter for a fine-tune, such as a hyperparameter. Each parameter declares a name, a type (one of `"string"`, `"int"`, `"float"`, `"bool"`), a description, and whether it is optional (the default).
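A short sketch showing how a declared parameter relates to `TYPE_MAP`, which resolves the `Literal` type names to the Python types used by `validate_parameters` (the `epochs` parameter is a made-up example):

```python
from kiln_ai.adapters.fine_tune.base_finetune import TYPE_MAP, FineTuneParameter

# Hypothetical hyperparameter declaration
epochs = FineTuneParameter(
    name="epochs",
    type="int",
    description="Number of training epochs",
    optional=True,
)

# TYPE_MAP turns the declared type name into the Python type
# that validate_parameters uses for isinstance checks
assert TYPE_MAP[epochs.type] is int
```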
## BaseFinetuneAdapter

A base class for fine-tuning adapters. Subclasses implement `_start` (launch the job) and `status` (report progress); the base class handles provider/parameter validation, datamodel creation, and persistence. A minimal subclass sketch follows.
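A sketch of a concrete adapter, assuming a wholly hypothetical provider client (`_FakeProviderClient`) in place of a real SDK; the shape of the job id and state strings is an assumption for illustration:

```python
from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneStatus,
    FineTuneStatusType,
)
from kiln_ai.datamodel import DatasetSplit


class _FakeProviderClient:
    """Stand-in for a real provider SDK; purely illustrative."""

    def create_job(self, model: str, rows: object, system_message: str) -> str:
        return "job-123"

    def get_job_state(self, job_id: str) -> str:
        return "running"  # assumed to map onto FineTuneStatusType values


fake_provider = _FakeProviderClient()


class ExampleFinetuneAdapter(BaseFinetuneAdapter):
    def _start(self, dataset: DatasetSplit) -> None:
        # A real adapter would format and upload the training split,
        # then create a fine-tune job with the provider
        rows = dataset.split_contents[self.datamodel.train_split_name]
        self._job_id = fake_provider.create_job(
            model=self.datamodel.base_model_id,
            rows=rows,
            system_message=self.datamodel.system_message,
        )

    def status(self) -> FineTuneStatus:
        # Map provider job states onto FineTuneStatusType; per the enum's
        # comment, `unknown` is the bucket for server errors
        try:
            state = fake_provider.get_job_state(self._job_id)
        except Exception as e:
            return FineTuneStatus(status=FineTuneStatusType.unknown, message=str(e))
        return FineTuneStatus(status=FineTuneStatusType(state), message=None)
```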
### create_and_start

Create and start a fine-tune. Validates the provider/model pair, the dataset splits, and the parameters; builds a `Finetune` datamodel (generating a memorable name if none is given), calls the subclass's `_start`, saves the datamodel to file, and returns the `(adapter, datamodel)` pair.
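A usage sketch, assuming the `ExampleFinetuneAdapter` above and an existing `DatasetSplit` (`dataset`) saved under a parent task; the provider and model ids are illustrative and must match an entry in `built_in_models`:

```python
adapter, finetune_model = ExampleFinetuneAdapter.create_and_start(
    dataset=dataset,  # an existing DatasetSplit with "train"/"val" splits
    provider_id="openai",  # illustrative; must exist in built_in_models
    provider_base_model_id="gpt-4o-mini-2024-07-18",  # illustrative
    train_split_name="train",
    system_message="You are a helpful assistant.",
    validation_split_name="val",  # optional; must be in split_contents
)
print(finetune_model.name)  # a generated memorable name if none was passed
```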
### status

Get the current status of the fine-tune as a `FineTuneStatus`.
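Since each call returns the current state, callers can poll until a terminal status. A simple illustrative helper (not part of the module), which treats `unknown` as retryable:

```python
import time

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneStatus,
    FineTuneStatusType,
)

TERMINAL_STATES = {FineTuneStatusType.completed, FineTuneStatusType.failed}


def wait_for_finetune(
    adapter: BaseFinetuneAdapter, poll_seconds: float = 30.0
) -> FineTuneStatus:
    """Poll until the fine-tune completes or fails; `unknown` is retried."""
    while True:
        current = adapter.status()
        if current.status in TERMINAL_STATES:
            return current
        time.sleep(poll_seconds)
```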
### available_parameters

Returns the list of parameters that can be provided for this fine-tune. The base implementation returns an empty list; providers override it to declare their supported hyperparameters, as in the sketch below.
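A hypothetical override declaring two hyperparameters; the abstract methods are omitted for brevity, since only the classmethods are exercised here:

```python
from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneParameter,
)


class ParamDemoAdapter(BaseFinetuneAdapter):
    # _start and status omitted; classmethods work on the class itself

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="epochs",
                type="int",
                description="Number of training epochs",
                optional=True,
            ),
            FineTuneParameter(
                name="learning_rate",
                type="float",
                description="Optimizer learning rate",
                optional=False,  # validate_parameters will require this one
            ),
        ]
```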
### validate_parameters

Validate the parameters for this fine-tune: every non-optional declared parameter must be present, every provided value must match its declared type (with strict checks for `int` and `float`), and no undeclared parameter names are allowed.
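With the `ParamDemoAdapter` declaration above, validation behaves as follows. Note the strict numeric check: an `int` is rejected where a `float` is declared, so `3` fails for `learning_rate` while `3.0` would pass:

```python
# Passes: required float present, optional int typed correctly
ParamDemoAdapter.validate_parameters({"learning_rate": 2e-5, "epochs": 3})

# Each of these raises ValueError:
# ParamDemoAdapter.validate_parameters({"epochs": 3})         # learning_rate is required
# ParamDemoAdapter.validate_parameters({"learning_rate": 3})  # int rejected for float
# ParamDemoAdapter.validate_parameters(
#     {"learning_rate": 2e-5, "batch_size": 8}                # unknown parameter name
# )
```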
### check_valid_provider_model

Check that the provider and base model pair exists in `built_in_models`; raises `ValueError` otherwise.
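It returns `None` on a match and raises otherwise; since it is a classmethod on the ABC, it can be called without a concrete adapter (ids illustrative):

```python
from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter

try:
    BaseFinetuneAdapter.check_valid_provider_model(
        "openai", "gpt-4o-mini-2024-07-18"
    )
except ValueError as e:
    print(e)  # "Provider ... with base model ... is not available"
```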