kiln_ai.adapters.fine_tune.dataset_split

import math
import random
from typing import TYPE_CHECKING, Callable

from pydantic import BaseModel, Field, model_validator

from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel

if TYPE_CHECKING:
    from kiln_ai.datamodel import Task, TaskRun

# Type alias for dataset filters: a predicate over a single TaskRun
DatasetFilter = Callable[["TaskRun"], bool]


def AllDatasetFilter(_: "TaskRun") -> bool:
    return True


def HighRatingDatasetFilter(task_run: "TaskRun") -> bool:
    if task_run.output is None or task_run.output.rating is None:
        return False
    return task_run.output.rating.is_high_quality()


class DatasetSplitDefinition(BaseModel):
    """
    A definition of a split in a dataset.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )


class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    You probably want to use the DatasetSplit class from the datamodel module, which has relationships to the Task and TaskRun models.
    """

    # TODO: NAME_FIELD
    name: str
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter = AllDatasetFilter,
        description: str | None = None,
    ):
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        valid_ids = []
        for task_run in task.runs():
            if filter(task_run):
                valid_ids.append(task_run.id)

        # Shuffle, then slice the shuffled IDs by each split's percentage
        random.shuffle(valid_ids)
        split_contents = {}
        start_idx = 0

        # Size every split except the last one with round()
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size

        # The last split gets all remaining items, absorbing any rounding error
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents
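The validate_split_percentages validator rejects any DatasetSplit whose split percentages do not sum to 1.0 (within math.isclose tolerance), raising a ValueError at construction time. A standard train/validation/test layout that passes it:

from kiln_ai.adapters.fine_tune.dataset_split import DatasetSplitDefinition

splits = [
    DatasetSplitDefinition(name="train", percentage=0.8),
    DatasetSplitDefinition(name="val", percentage=0.1),
    DatasetSplitDefinition(name="test", percentage=0.1),
]
# 0.8 + 0.1 + 0.1 sums to 1.0, so a DatasetSplit built from these passes
# validation; dropping one entry would raise a ValueError instead.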
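Putting it together, from_task filters a task's runs, builds the split contents, and returns a DatasetSplit parented to the task. A usage sketch, assuming task is an existing kiln_ai.datamodel.Task with saved runs:

from kiln_ai.adapters.fine_tune.dataset_split import (
    DatasetSplit,
    DatasetSplitDefinition,
    HighRatingDatasetFilter,
)

dataset = DatasetSplit.from_task(
    name="high_rating_train_test",
    task=task,  # an existing Task with saved runs (assumed)
    splits=[
        DatasetSplitDefinition(name="train", percentage=0.8),
        DatasetSplitDefinition(name="test", percentage=0.2),
    ],
    filter=HighRatingDatasetFilter,
)
train_ids = dataset.split_contents["train"]  # shuffled task run IDs

Passing filter=HighRatingDatasetFilter keeps only runs whose output rating reports is_high_quality(); the default AllDatasetFilter keeps every run.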

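build_split_contents shuffles the filtered run IDs and slices them in order: every split except the last is sized with round(), and the last split takes whatever remains, so rounding never drops or duplicates an ID. A worked example of the slicing for ten IDs and an 0.8/0.1/0.1 split:

import random

valid_ids = [str(i) for i in range(10)]
random.shuffle(valid_ids)

train = valid_ids[0:8]  # round(10 * 0.8) == 8
val = valid_ids[8:9]    # round(10 * 0.1) == 1
test = valid_ids[9:]    # the last split takes the 1 remaining ID

With 9 IDs and the same percentages the sizes come out 7 / 1 / 1: round(9 * 0.8) == 7, round(9 * 0.1) == 1, and the last split absorbs the leftover ID.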
