kiln_ai.adapters.fine_tune.base_finetune
from abc import ABC, abstractmethod
from typing import Literal

from pydantic import BaseModel

from kiln_ai.adapters.ml_model_list import built_in_models
from kiln_ai.datamodel import (
    DatasetSplit,
    FinetuneDataStrategy,
    FineTuneStatusType,
    Task,
)
from kiln_ai.datamodel import Finetune as FinetuneModel
from kiln_ai.utils.name_generator import generate_memorable_name


class FineTuneStatus(BaseModel):
    """
    The status of a fine-tune, including a user-friendly message.
    """

    status: FineTuneStatusType
    message: str | None = None


class FineTuneParameter(BaseModel):
    """
    A parameter for a fine-tune. Hyperparameters, etc.
    """

    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


TYPE_MAP = {
    "string": str,
    "int": int,
    "float": float,
    "bool": bool,
}


class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        self.datamodel = datamodel

    @classmethod
    async def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        thinking_instructions: str | None,
        data_strategy: FinetuneDataStrategy,
        parameters: dict[str, str | int | float | bool] = {},
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.
        """

        cls.check_valid_provider_model(provider_id, provider_base_model_id)

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=cls.augment_system_message(system_message, parent_task),
            thinking_instructions=thinking_instructions,
            parent=parent_task,
            data_strategy=data_strategy,
        )

        adapter = cls(datamodel)
        await adapter._start(dataset)

        datamodel.save_to_file()

        return adapter, datamodel

    @classmethod
    def augment_system_message(cls, system_message: str, task: Task) -> str:
        """
        Augment the system message with additional instructions, such as JSON instructions.
        """

        # Base implementation does nothing, can be overridden by subclasses
        return system_message

    @abstractmethod
    async def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    async def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.
        """
        # Check required parameters and parameter types
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if not parameter.optional and parameter.name not in parameters:
                raise ValueError(f"Parameter {parameter.name} is required")
            elif parameter.name in parameters:
                # check parameter is correct type
                expected_type = TYPE_MAP[parameter.type]
                value = parameters[parameter.name]

                # Strict type checking for numeric types
                if expected_type is float and not isinstance(value, float):
                    raise ValueError(
                        f"Parameter {parameter.name} must be a float, got {type(value)}"
                    )
                elif expected_type is int and not isinstance(value, int):
                    raise ValueError(
                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
                    )
                elif not isinstance(value, expected_type):
                    raise ValueError(
                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                    )

        allowed_parameters = [p.name for p in available_parameters]
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")

    @classmethod
    def check_valid_provider_model(
        cls, provider_id: str, provider_base_model_id: str
    ) -> None:
        """
        Check if the provider and base model are valid.
        """
        for model in built_in_models:
            for provider in model.providers:
                if (
                    provider.name == provider_id
                    and provider.provider_finetune_id == provider_base_model_id
                ):
                    return
        raise ValueError(
            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
        )
class FineTuneStatus(pydantic.main.BaseModel):
The status of a fine-tune, including a user-friendly message.
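For illustration, a status object is a plain Pydantic model, so a provider adapter can construct one directly. The enum member used below (FineTuneStatusType.completed) is an assumption about the datamodel; substitute whichever members your version of FineTuneStatusType actually defines.

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatus
from kiln_ai.datamodel import FineTuneStatusType

# Hypothetical example: "completed" is an assumed enum member.
status = FineTuneStatus(
    status=FineTuneStatusType.completed,
    message="Fine-tune job finished; the model is ready to use.",
)
print(status.status, status.message)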
class FineTuneParameter(pydantic.main.BaseModel):
A parameter for a fine-tune. Hyperparameters, etc.
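For example, an adapter that accepts an epoch-count hyperparameter would describe it like this ("epochs" is an illustrative name, not one any built-in adapter is guaranteed to expose):

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter

epochs_param = FineTuneParameter(
    name="epochs",
    type="int",
    description="Number of passes over the training split.",
    optional=True,
)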
TYPE_MAP = {'string': <class 'str'>, 'int': <class 'int'>, 'float': <class 'float'>, 'bool': <class 'bool'>}
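TYPE_MAP resolves the string names used in FineTuneParameter.type to the Python types that validate_parameters checks values against. Continuing the epochs_param example above:

from kiln_ai.adapters.fine_tune.base_finetune import TYPE_MAP

expected = TYPE_MAP[epochs_param.type]  # <class 'int'>
print(isinstance(3, expected))    # True
print(isinstance(3.5, expected))  # False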
class BaseFinetuneAdapter(abc.ABC):
A base class for fine-tuning adapters.
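A minimal sketch of a concrete adapter, assuming a hypothetical provider called "example". Subclasses must implement the two abstract methods, _start and status, and usually override available_parameters; the provider ids and enum member shown are placeholders, not real Kiln providers.

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneParameter,
    FineTuneStatus,
)
from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType


class ExampleFinetuneAdapter(BaseFinetuneAdapter):
    """Hypothetical adapter for a provider called 'example'."""

    async def _start(self, dataset: DatasetSplit) -> None:
        # A real adapter would format the training split, upload it, and
        # create a provider-side job, keeping the job id for status polling.
        self._job_id = "job-123"

    async def status(self) -> FineTuneStatus:
        # A real adapter would poll the provider API using the stored job id.
        return FineTuneStatus(
            status=FineTuneStatusType.completed,  # assumed enum member
            message="Example job finished.",
        )

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="epochs",
                type="int",
                description="Number of passes over the training split.",
            )
        ]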
@classmethod
async def create_and_start(
    cls,
    dataset: kiln_ai.datamodel.DatasetSplit,
    provider_id: str,
    provider_base_model_id: str,
    train_split_name: str,
    system_message: str,
    thinking_instructions: str | None,
    data_strategy: kiln_ai.datamodel.FinetuneDataStrategy,
    parameters: dict[str, str | int | float | bool] = {},
    name: str | None = None,
    description: str | None = None,
    validation_split_name: str | None = None,
) -> tuple[BaseFinetuneAdapter, kiln_ai.datamodel.Finetune]:
Create and start a fine-tune.
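A sketch of the call site, using the ExampleFinetuneAdapter above. It assumes you already have a DatasetSplit saved under a parent Task; the provider and model ids must match an entry in built_in_models or check_valid_provider_model raises, and FinetuneDataStrategy.final_only is an assumed enum member.

from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy


async def start_tune(dataset: DatasetSplit):
    adapter, finetune = await ExampleFinetuneAdapter.create_and_start(
        dataset=dataset,
        provider_id="example",                    # placeholder provider id
        provider_base_model_id="example-base-1",  # placeholder model id
        train_split_name="train",                 # must be a key in dataset.split_contents
        system_message="You are a helpful assistant.",
        thinking_instructions=None,
        data_strategy=FinetuneDataStrategy.final_only,  # assumed enum member
        parameters={"epochs": 3},
    )
    return adapter, finetune

Note the ordering: all validation happens before the job is started, and the Finetune record is saved to disk (as a child of the dataset's parent task) only after _start succeeds.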
@classmethod
def augment_system_message(cls, system_message: str, task: kiln_ai.datamodel.Task) -> str:
Augment the system message with additional instructions, such as JSON instructions.
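The base implementation returns the message unchanged. A subclass could append format instructions, for example (this sketch assumes the Task exposes an output_json_schema field, which is an assumption about the datamodel):

from kiln_ai.datamodel import Task


class JsonAwareAdapter(ExampleFinetuneAdapter):
    @classmethod
    def augment_system_message(cls, system_message: str, task: Task) -> str:
        # Assumption: Task has an output_json_schema attribute; adjust to your datamodel.
        if getattr(task, "output_json_schema", None):
            return system_message + "\n\nReturn only valid JSON matching the task's output schema."
        return system_message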
@abstractmethod
async def status(self) -> FineTuneStatus:
Get the status of the fine-tune.
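Because providers run fine-tunes asynchronously, callers typically poll status() until it reaches a terminal state. A sketch, assuming FineTuneStatusType has terminal members named completed and failed:

import asyncio

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter, FineTuneStatus
from kiln_ai.datamodel import FineTuneStatusType


async def wait_for_completion(
    adapter: BaseFinetuneAdapter, poll_seconds: float = 30.0
) -> FineTuneStatus:
    while True:
        current = await adapter.status()
        # Assumed terminal members; adjust to the real enum.
        if current.status in (FineTuneStatusType.completed, FineTuneStatusType.failed):
            return current
        await asyncio.sleep(poll_seconds)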
@classmethod
def available_parameters(cls) -> list[FineTuneParameter]:
Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
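The base implementation returns an empty list; subclasses override it (as ExampleFinetuneAdapter does above), and callers can enumerate the result to build a settings form or CLI help text:

for param in ExampleFinetuneAdapter.available_parameters():
    required = "optional" if param.optional else "required"
    print(f"{param.name} ({param.type}, {required}): {param.description}")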
@classmethod
def validate_parameters(cls, parameters: dict[str, str | int | float | bool]) -> None:
Validate the parameters for this fine-tune.
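The checks are strict: a required parameter must be present, an int is rejected where a float is declared, and any key not returned by available_parameters() raises. Illustrated with the example adapter above:

ExampleFinetuneAdapter.validate_parameters({"epochs": 3})  # passes silently

try:
    ExampleFinetuneAdapter.validate_parameters({"epochs": "3"})
except ValueError as err:
    print(err)  # Parameter epochs must be an integer, got <class 'str'>

try:
    ExampleFinetuneAdapter.validate_parameters({"learning_rate": 1e-4})
except ValueError as err:
    print(err)  # Parameter learning_rate is not available

One caveat worth knowing: because bool is a subclass of int in Python, isinstance(True, int) is True, so a boolean value will pass an "int" parameter's check.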
@classmethod
def check_valid_provider_model(cls, provider_id: str, provider_base_model_id: str) -> None:
Check if the provider and base model are valid.
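Validity is simply membership in built_in_models: some provider entry must match both the provider name and its provider_finetune_id. Since this is a classmethod on the base class, it also works as a pre-flight check (the ids below are placeholders):

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter

try:
    BaseFinetuneAdapter.check_valid_provider_model("example", "example-base-1")
except ValueError:
    print("This provider/model pair is not available for fine-tuning.")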