kiln_ai.adapters.fine_tune.dataset_split
import math
import random
from typing import TYPE_CHECKING, Callable

from pydantic import BaseModel, Field, model_validator

from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel

if TYPE_CHECKING:
    from kiln_ai.datamodel import Task, TaskRun

# A dataset filter decides whether a given TaskRun belongs in the dataset.
DatasetFilter = Callable[["TaskRun"], bool]


def AllDatasetFilter(_: "TaskRun") -> bool:
    """Accept every task run (no filtering)."""
    return True


def HighRatingDatasetFilter(task_run: "TaskRun") -> bool:
    """Accept only task runs whose output carries a high-quality rating."""
    if task_run.output is None or task_run.output.rating is None:
        return False
    return task_run.output.rating.is_high_quality()


class DatasetSplitDefinition(BaseModel):
    """
    A definition of a split in a dataset.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )


class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    You probably want to use the DatasetSplit class from the datamodel module, which has
    relationships to the Task and TaskRun models.
    """

    # Use the shared NAME_FIELD so dataset names get the same validation as
    # DatasetSplitDefinition (resolves the previous TODO).
    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        """Ensure the split percentages sum to exactly 1.0 (within float tolerance)."""
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter = AllDatasetFilter,
        description: str | None = None,
    ):
        """
        Build a DatasetSplit from a task's runs.

        Runs are filtered by `filter`, shuffled, and partitioned per `splits`.
        """
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        """
        Shuffle the task's filtered run IDs and partition them by split percentage.

        The final split receives all remaining IDs, so rounding never drops a run.
        """
        valid_ids = [task_run.id for task_run in task.runs() if filter(task_run)]

        # Shuffle so splits are random samples rather than insertion-ordered.
        random.shuffle(valid_ids)
        split_contents: dict[str, list[str]] = {}
        start_idx = 0

        # All splits except the last take round(percentage * total) items.
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size

        # Last split gets all remaining items (absorbs rounding error).
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents
class DatasetSplitDefinition(BaseModel):
    """
    A definition of a split in a dataset.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    # Split name; NAME_FIELD applies the project's shared name validation rules.
    name: str = NAME_FIELD
    # Free-form human description; not used in training.
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    # Fraction of the dataset this split represents; pydantic enforces [0, 1].
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )
A definition of a split in a dataset.
Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
Configuration for the model; should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Metadata about the fields defined on the model: a mapping of field names to [FieldInfo][pydantic.fields.FieldInfo] objects. This replaces `Model.__fields__` from Pydantic V1.
class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    You probably want to use the DatasetSplit class from the datamodel module, which has
    relationships to the Task and TaskRun models.
    """

    # Use the shared NAME_FIELD so dataset names get the same validation as
    # DatasetSplitDefinition (resolves the previous TODO).
    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        """Ensure the split percentages sum to exactly 1.0 (within float tolerance)."""
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter = AllDatasetFilter,
        description: str | None = None,
    ):
        """
        Build a DatasetSplit from a task's runs.

        Runs are filtered by `filter`, shuffled, and partitioned per `splits`.
        """
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        """
        Shuffle the task's filtered run IDs and partition them by split percentage.

        The final split receives all remaining IDs, so rounding never drops a run.
        """
        valid_ids = [task_run.id for task_run in task.runs() if filter(task_run)]

        # Shuffle so splits are random samples rather than insertion-ordered.
        random.shuffle(valid_ids)
        split_contents: dict[str, list[str]] = {}
        start_idx = 0

        # All splits except the last take round(percentage * total) items.
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size

        # Last split gets all remaining items (absorbs rounding error).
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents
A collection of task runs, with optional splits (train, test, validation)
You probably want to use the DatasetSplit class from the datamodel module, which has relationships to the Task and TaskRun models.
73 @classmethod 74 def from_task( 75 cls, 76 name: str, 77 task: "Task", 78 splits: list[DatasetSplitDefinition], 79 filter: DatasetFilter = AllDatasetFilter, 80 description: str | None = None, 81 ): 82 split_contents = cls.build_split_contents(task, splits, filter) 83 return cls( 84 parent=task, 85 name=name, 86 description=description, 87 splits=splits, 88 split_contents=split_contents, 89 )
91 @classmethod 92 def build_split_contents( 93 cls, 94 task: "Task", 95 splits: list[DatasetSplitDefinition], 96 filter: DatasetFilter, 97 ) -> dict[str, list[str]]: 98 valid_ids = [] 99 for task_run in task.runs(): 100 if filter(task_run): 101 valid_ids.append(task_run.id) 102 103 # Shuffle and split by split percentage 104 random.shuffle(valid_ids) 105 split_contents = {} 106 start_idx = 0 107 remaining_items = len(valid_ids) 108 109 # Handle all splits except the last one 110 for split in splits[:-1]: 111 split_size = round(len(valid_ids) * split.percentage) 112 split_contents[split.name] = valid_ids[start_idx : start_idx + split_size] 113 start_idx += split_size 114 remaining_items -= split_size 115 116 # Last split gets all remaining items (for rounding) 117 if splits: 118 split_contents[splits[-1].name] = valid_ids[start_idx:] 119 120 return split_contents
Configuration for the model; should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Metadata about the fields defined on the model: a mapping of field names to [FieldInfo][pydantic.fields.FieldInfo] objects. This replaces `Model.__fields__` from Pydantic V1.
A dictionary of computed field names and their corresponding ComputedFieldInfo objects.
    def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
        """We need to both initialize private attributes and call the user-defined model_post_init
        method.
        """
        # NOTE(review): this fragment appears to be pydantic-internal residue;
        # both callees are defined outside this view — confirm before relying on it.
        init_private_attributes(self, context)
        original_model_post_init(self, context)
We need to both initialize private attributes and call the user-defined model_post_init method.