kiln_ai.datamodel
See our docs for details about our datamodel: https://kiln-ai.github.io/Kiln/kiln_core_docs/kiln_ai.html
1""" 2See our docs for details about our datamodel: https://kiln-ai.github.io/Kiln/kiln_core_docs/kiln_ai.html 3""" 4 5from __future__ import annotations 6 7import json 8import math 9import random 10from enum import Enum, IntEnum 11from typing import TYPE_CHECKING, Callable, Dict, List, Type, Union 12 13import jsonschema 14import jsonschema.exceptions 15from pydantic import ( 16 BaseModel, 17 Field, 18 ValidationInfo, 19 model_validator, 20) 21from typing_extensions import Self 22 23from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str 24 25from .basemodel import ( 26 ID_FIELD, 27 ID_TYPE, 28 NAME_FIELD, 29 SHORT_NAME_FIELD, 30 KilnBaseModel, 31 KilnParentedModel, 32 KilnParentModel, 33) 34from .json_schema import validate_schema 35 36if TYPE_CHECKING: 37 from . import Task 38 39 40__all__ = [ 41 "basemodel", 42 "json_schema", 43 "Task", 44 "Project", 45 "TaskRun", 46 "TaskOutput", 47 "TaskOutputRating", 48 "Priority", 49 "DataSource", 50 "DataSourceType", 51 "DataSourceProperty", 52 "Finetune", 53 "FineTuneStatusType", 54 "TaskOutputRatingType", 55 "TaskRequirement", 56 "TaskDeterminism", 57 "DatasetSplitDefinition", 58 "DatasetSplit", 59 "RequirementRating", 60 "TaskRequirement", 61 "strict_mode", 62 "set_strict_mode", 63 "Prompt", 64] 65 66 67# We want to be hard on ourselves for data completeness generated by the Kiln App, but don't want to make it hard for users to use the datamodel/library. 68# Strict mode enables extra validations that we want to enforce in Kiln App (and any other client that wants best practices), but not in the library (unless they opt in) 69_strict_mode: bool = False 70 71 72def strict_mode() -> bool: 73 return _strict_mode 74 75 76def set_strict_mode(value: bool) -> None: 77 global _strict_mode 78 _strict_mode = value 79 80 81class Priority(IntEnum): 82 """Defines priority levels for tasks and requirements, where P0 is highest priority.""" 83 84 p0 = 0 85 p1 = 1 86 p2 = 2 87 p3 = 3 88 89 90# Only one rating type for now, but this allows for extensibility if we want to add more in the future 91class TaskOutputRatingType(str, Enum): 92 """Defines the types of rating systems available for task outputs.""" 93 94 five_star = "five_star" 95 pass_fail = "pass_fail" 96 pass_fail_critical = "pass_fail_critical" 97 custom = "custom" 98 99 100class RequirementRating(BaseModel): 101 """Rating for a specific requirement within a task output.""" 102 103 value: float = Field( 104 description="The rating value. Interpretation depends on rating type" 105 ) 106 type: TaskOutputRatingType = Field(description="The type of rating") 107 108 109class TaskOutputRating(KilnBaseModel): 110 """ 111 A rating for a task output, including an overall rating and ratings for each requirement. 112 113 Supports: 114 - five_star: 1-5 star ratings 115 - pass_fail: boolean pass/fail (1.0 = pass, 0.0 = fail) 116 - pass_fail_critical: tri-state (1.0 = pass, 0.0 = fail, -1.0 = critical fail) 117 """ 118 119 type: TaskOutputRatingType = Field(default=TaskOutputRatingType.five_star) 120 value: float | None = Field( 121 description="The rating value. 
Interpretation depends on rating type:\n- five_star: 1-5 stars\n- pass_fail: 1.0 (pass) or 0.0 (fail)\n- pass_fail_critical: 1.0 (pass), 0.0 (fail), or -1.0 (critical fail)", 122 default=None, 123 ) 124 requirement_ratings: Dict[ID_TYPE, RequirementRating] = Field( 125 default={}, 126 description="The ratings of the requirements of the task.", 127 ) 128 129 # Previously we stored rating values as a dict of floats, but now we store them as RequirementRating objects. 130 @model_validator(mode="before") 131 def upgrade_old_format(cls, data: dict) -> dict: 132 if not isinstance(data, dict): 133 return data 134 135 # Check if we have the old format (dict of floats) 136 req_ratings = data.get("requirement_ratings", {}) 137 if req_ratings and all( 138 isinstance(v, (int, float)) for v in req_ratings.values() 139 ): 140 # Convert each float to a RequirementRating object 141 # all ratings are five star at the point we used this format 142 data["requirement_ratings"] = { 143 k: {"value": v, "type": TaskOutputRatingType.five_star} 144 for k, v in req_ratings.items() 145 } 146 147 return data 148 149 # Used to select high quality outputs for example selection (MultiShotPromptBuilder, etc) 150 def is_high_quality(self) -> bool: 151 if self.value is None: 152 return False 153 154 if self.type == TaskOutputRatingType.five_star: 155 return self.value >= 4 156 elif self.type == TaskOutputRatingType.pass_fail: 157 return self.value == 1.0 158 elif self.type == TaskOutputRatingType.pass_fail_critical: 159 return self.value == 1.0 160 return False 161 162 @model_validator(mode="after") 163 def validate_rating(self) -> Self: 164 if self.type not in TaskOutputRatingType: 165 raise ValueError(f"Invalid rating type: {self.type}") 166 167 # Overall rating is optional 168 if self.value is not None: 169 self._validate_rating(self.type, self.value, "overall rating") 170 171 for req_id, req_rating in self.requirement_ratings.items(): 172 self._validate_rating( 173 req_rating.type, 174 req_rating.value, 175 f"requirement rating for req ID: {req_id}", 176 ) 177 178 return self 179 180 def _validate_rating( 181 self, type: TaskOutputRatingType, rating: float | None, rating_name: str 182 ) -> None: 183 if type == TaskOutputRatingType.five_star: 184 self._validate_five_star(rating, rating_name) 185 elif type == TaskOutputRatingType.pass_fail: 186 self._validate_pass_fail(rating, rating_name) 187 elif type == TaskOutputRatingType.pass_fail_critical: 188 self._validate_pass_fail_critical(rating, rating_name) 189 190 def _validate_five_star(self, rating: float | None, rating_name: str) -> None: 191 if rating is None or not isinstance(rating, float) or not rating.is_integer(): 192 raise ValueError( 193 f"{rating_name.capitalize()} of type five_star must be an integer value (1-5)" 194 ) 195 if rating < 1 or rating > 5: 196 raise ValueError( 197 f"{rating_name.capitalize()} of type five_star must be between 1 and 5 stars" 198 ) 199 200 def _validate_pass_fail(self, rating: float | None, rating_name: str) -> None: 201 if rating is None or not isinstance(rating, float) or not rating.is_integer(): 202 raise ValueError( 203 f"{rating_name.capitalize()} of type pass_fail must be an integer value (0 or 1)" 204 ) 205 if rating not in [0, 1]: 206 raise ValueError( 207 f"{rating_name.capitalize()} of type pass_fail must be 0 (fail) or 1 (pass)" 208 ) 209 210 def _validate_pass_fail_critical( 211 self, rating: float | None, rating_name: str 212 ) -> None: 213 if rating is None or not isinstance(rating, float) or not rating.is_integer(): 
214 raise ValueError( 215 f"{rating_name.capitalize()} of type pass_fail_critical must be an integer value (-1, 0, or 1)" 216 ) 217 if rating not in [-1, 0, 1]: 218 raise ValueError( 219 f"{rating_name.capitalize()} of type pass_fail_critical must be -1 (critical fail), 0 (fail), or 1 (pass)" 220 ) 221 222 223class TaskOutput(KilnBaseModel): 224 """ 225 An output for a specific task run. 226 227 Contains the actual output content, its source (human or synthetic), 228 and optional rating information. 229 """ 230 231 output: str = Field( 232 description="The output of the task. JSON formatted for structured output, plaintext for unstructured output." 233 ) 234 source: DataSource | None = Field( 235 description="The source of the output: human or synthetic.", 236 default=None, 237 ) 238 rating: TaskOutputRating | None = Field( 239 default=None, description="The rating of the output" 240 ) 241 242 def validate_output_format(self, task: Task) -> Self: 243 # validate output 244 if task.output_json_schema is not None: 245 try: 246 validate_schema(json.loads(self.output), task.output_json_schema) 247 except json.JSONDecodeError: 248 raise ValueError("Output is not a valid JSON object") 249 except jsonschema.exceptions.ValidationError as e: 250 raise ValueError(f"Output does not match task output schema: {e}") 251 return self 252 253 @model_validator(mode="after") 254 def validate_output_source(self, info: ValidationInfo) -> Self: 255 # On strict mode and not loaded from file, we validate output_source is not None. 256 # We want to be able to load any data, even if it's not perfect. But we want to create perfect data when adding new data. 257 if not strict_mode(): 258 return self 259 if self.loaded_from_file(info): 260 return self 261 if self.source is None: 262 raise ValueError("Output source is required when strict mode is enabled") 263 return self 264 265 266class FineTuneStatusType(str, Enum): 267 """ 268 The status type of a fine-tune (running, completed, failed, etc). 269 """ 270 271 unknown = "unknown" # server error 272 pending = "pending" 273 running = "running" 274 completed = "completed" 275 failed = "failed" 276 277 278class StructuredOutputMode(str, Enum): 279 """ 280 Enumeration of supported structured output modes. 281 282 - default: let the adapter decide 283 - json_schema: request json using API capabilities for json_schema 284 - function_calling: request json using API capabilities for function calling 285 - json_mode: request json using API's JSON mode, which should return valid JSON, but isn't checking/passing the schema 286 - json_instructions: append instructions to the prompt to request json matching the schema. No API capabilities are used. You should have a custom parser on these models as they will be returning strings. 287 - json_instruction_and_object: append instructions to the prompt to request json matching the schema. Also request the response as json_mode via API capabilities (returning dictionaries). 288 """ 289 290 default = "default" 291 json_schema = "json_schema" 292 function_calling = "function_calling" 293 json_mode = "json_mode" 294 json_instructions = "json_instructions" 295 json_instruction_and_object = "json_instruction_and_object" 296 297 298class Finetune(KilnParentedModel): 299 """ 300 The Kiln fine-tune datamodel. 301 302 Initially holds a reference to a training job, with needed identifiers to update the status. When complete, contains the new model ID. 
303 """ 304 305 name: str = NAME_FIELD 306 description: str | None = Field( 307 default=None, 308 description="A description of the fine-tune for you and your team. Not used in training.", 309 ) 310 structured_output_mode: StructuredOutputMode | None = Field( 311 default=None, 312 description="The mode to use to train the model for structured output, if it was trained with structured output. Will determine how we call the tuned model, so we call with the matching mode.", 313 ) 314 provider: str = Field( 315 description="The provider to use for the fine-tune (e.g. 'openai')." 316 ) 317 base_model_id: str = Field( 318 description="The id of the base model to use for the fine-tune. This string relates to the provider's IDs for their own models, not Kiln IDs." 319 ) 320 provider_id: str | None = Field( 321 default=None, 322 description="The ID of the fine-tune job on the provider's side. May not be the same as the fine_tune_model_id.", 323 ) 324 fine_tune_model_id: str | None = Field( 325 default=None, 326 description="The ID of the fine-tuned model on the provider's side. May not be the same as the provider_id.", 327 ) 328 dataset_split_id: str = Field( 329 description="The ID of the dataset split to use for this fine-tune.", 330 ) 331 train_split_name: str = Field( 332 default="train", 333 description="The name of the training split to use for this fine-tune.", 334 ) 335 validation_split_name: str | None = Field( 336 default=None, 337 description="The name of the validation split to use for this fine-tune. Optional.", 338 ) 339 parameters: dict[str, str | int | float | bool] = Field( 340 default={}, 341 description="The parameters to use for this fine-tune. These are provider-specific.", 342 ) 343 system_message: str = Field( 344 description="The system message to use for this fine-tune.", 345 ) 346 latest_status: FineTuneStatusType = Field( 347 default=FineTuneStatusType.unknown, 348 description="The latest known status of this fine-tune. Not updated in real time.", 349 ) 350 properties: Dict[str, str | int | float] = Field( 351 default={}, 352 description="Properties of the fine-tune. Different providers may use different properties.", 353 ) 354 355 def parent_task(self) -> Task | None: 356 if not isinstance(self.parent, Task): 357 return None 358 return self.parent 359 360 361class DataSourceType(str, Enum): 362 """ 363 The source type of a piece of data. 364 365 Human: a human created the data 366 Synthetic: a model created the data 367 """ 368 369 human = "human" 370 synthetic = "synthetic" 371 372 373class DataSourceProperty(BaseModel): 374 """ 375 Defines a property that can be associated with a data source. 376 377 Includes validation rules for when properties are required or not allowed 378 based on the data source type. 379 """ 380 381 name: str 382 type: Type[Union[str, int, float]] 383 required_for: List[DataSourceType] = [] 384 not_allowed_for: List[DataSourceType] = [] 385 386 387class DataSource(BaseModel): 388 """ 389 Represents the origin of data, either human or synthetic, with associated properties. 390 391 Properties vary based on the source type - for synthetic sources this includes 392 model information, for human sources this includes creator information. 393 """ 394 395 type: DataSourceType 396 properties: Dict[str, str | int | float] = Field( 397 default={}, 398 description="Properties describing the data source. For synthetic things like model. 
For human, the human's name.", 399 ) 400 401 _data_source_properties = [ 402 DataSourceProperty( 403 name="created_by", 404 type=str, 405 required_for=[DataSourceType.human], 406 not_allowed_for=[DataSourceType.synthetic], 407 ), 408 DataSourceProperty( 409 name="model_name", 410 type=str, 411 required_for=[DataSourceType.synthetic], 412 not_allowed_for=[DataSourceType.human], 413 ), 414 DataSourceProperty( 415 name="model_provider", 416 type=str, 417 required_for=[DataSourceType.synthetic], 418 not_allowed_for=[DataSourceType.human], 419 ), 420 DataSourceProperty( 421 name="adapter_name", 422 type=str, 423 required_for=[DataSourceType.synthetic], 424 not_allowed_for=[DataSourceType.human], 425 ), 426 DataSourceProperty( 427 name="prompt_builder_name", 428 type=str, 429 not_allowed_for=[DataSourceType.human], 430 ), 431 DataSourceProperty( 432 # Optional prompt builders with IDs (like static prompts) 433 name="prompt_id", 434 type=str, 435 not_allowed_for=[DataSourceType.human], 436 ), 437 ] 438 439 @model_validator(mode="after") 440 def validate_type(self) -> "DataSource": 441 if self.type not in DataSourceType: 442 raise ValueError(f"Invalid data source type: {self.type}") 443 return self 444 445 @model_validator(mode="after") 446 def validate_properties(self) -> "DataSource": 447 for prop in self._data_source_properties: 448 # Check the property type is correct 449 if prop.name in self.properties: 450 if not isinstance(self.properties[prop.name], prop.type): 451 raise ValueError( 452 f"'{prop.name}' must be of type {prop.type.__name__} for {self.type} data source" 453 ) 454 # Check the property is required for the data source type 455 if self.type in prop.required_for: 456 if prop.name not in self.properties: 457 raise ValueError( 458 f"'{prop.name}' is required for {self.type} data source" 459 ) 460 # Check the property is not allowed for the data source type 461 elif self.type in prop.not_allowed_for and prop.name in self.properties: 462 raise ValueError( 463 f"'{prop.name}' is not allowed for {self.type} data source" 464 ) 465 return self 466 467 @model_validator(mode="after") 468 def validate_no_empty_properties(self) -> Self: 469 for prop, value in self.properties.items(): 470 if isinstance(value, str) and value == "": 471 raise ValueError( 472 f"Property '{prop}' must be a non-empty string for {self.type} data source" 473 ) 474 return self 475 476 477class TaskRun(KilnParentedModel): 478 """ 479 Represents a single execution of a Task. 480 481 Contains the input used, its source, the output produced, and optional 482 repair information if the output needed correction. 483 """ 484 485 input: str = Field( 486 description="The inputs to the task. JSON formatted for structured input, plaintext for unstructured input." 487 ) 488 input_source: DataSource | None = Field( 489 default=None, description="The source of the input: human or synthetic." 490 ) 491 492 output: TaskOutput = Field(description="The output of the task run.") 493 repair_instructions: str | None = Field( 494 default=None, 495 description="Instructions for fixing the output. Should define what is wrong, and how to fix it. Will be used by models for both generating a fixed output, and evaluating future models.", 496 ) 497 repaired_output: TaskOutput | None = Field( 498 default=None, 499 description="An version of the output with issues fixed. This must be a 'fixed' version of the existing output, and not an entirely new output. 
If you wish to generate an ideal curatorial output for this task unrelated to this output, generate a new TaskOutput with type 'human' instead of using this field.", 500 ) 501 intermediate_outputs: Dict[str, str] | None = Field( 502 default=None, 503 description="Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.", 504 ) 505 tags: List[str] = Field( 506 default=[], 507 description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.", 508 ) 509 510 def parent_task(self) -> Task | None: 511 if not isinstance(self.parent, Task): 512 return None 513 return self.parent 514 515 @model_validator(mode="after") 516 def validate_input_format(self, info: ValidationInfo) -> Self: 517 # Don't validate if loading from file (not new). Too slow. 518 # We don't allow changing task schema, so this is redundant validation. 519 # Note: we still validate if editing a loaded model 520 if self.loading_from_file(info): 521 # Consider loading an existing model as validated. 522 self._last_validated_input = self.input 523 return self 524 525 # Don't validate if input has not changed. Too slow to run this every time. 526 if ( 527 hasattr(self, "_last_validated_input") 528 and self.input == self._last_validated_input 529 ): 530 return self 531 532 task = self.parent_task() 533 if task is None: 534 # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving) 535 return self 536 537 # validate output 538 if task.input_json_schema is not None: 539 try: 540 validate_schema(json.loads(self.input), task.input_json_schema) 541 except json.JSONDecodeError: 542 raise ValueError("Input is not a valid JSON object") 543 except jsonschema.exceptions.ValidationError as e: 544 raise ValueError(f"Input does not match task input schema: {e}") 545 self._last_validated_input = self.input 546 return self 547 548 @model_validator(mode="after") 549 def validate_output_format(self, info: ValidationInfo) -> Self: 550 # Don't validate if loading from file (not new). Too slow. 551 # Note: we still validate if editing a loaded model's output. 552 if self.loading_from_file(info): 553 # Consider loading an existing model as validated. 554 self._last_validated_output = self.output.output if self.output else None 555 return self 556 557 # Don't validate unless output has changed since last validation. 558 # The validator is slow and costly, don't want it running when setting other fields. 559 if ( 560 hasattr(self, "_last_validated_output") 561 and self.output is not None 562 and self.output.output == self._last_validated_output 563 ): 564 return self 565 566 task = self.parent_task() 567 if task is None: 568 return self 569 570 self.output.validate_output_format(task) 571 self._last_validated_output = self.output.output if self.output else None 572 return self 573 574 @model_validator(mode="after") 575 def validate_repaired_output(self) -> Self: 576 if self.repaired_output is not None: 577 if self.repaired_output.rating is not None: 578 raise ValueError( 579 "Repaired output rating must be None. Repaired outputs are assumed to have a perfect rating, as they have been fixed." 580 ) 581 if self.repair_instructions is None and self.repaired_output is not None: 582 raise ValueError( 583 "Repair instructions are required if providing a repaired output." 
584 ) 585 if self.repair_instructions is not None and self.repaired_output is None: 586 raise ValueError( 587 "A repaired output is required if providing repair instructions." 588 ) 589 return self 590 591 @model_validator(mode="after") 592 def validate_input_source(self, info: ValidationInfo) -> Self: 593 # On strict mode and not loaded from file, we validate input_source is not None. 594 # We want to be able to load any data, even if it's not perfect. But we want to create perfect data when adding new data. 595 if not strict_mode(): 596 return self 597 if self.loaded_from_file(info): 598 return self 599 if self.input_source is None: 600 raise ValueError("input_source is required when strict mode is enabled") 601 return self 602 603 @model_validator(mode="after") 604 def validate_tags(self) -> Self: 605 for tag in self.tags: 606 if not tag: 607 raise ValueError("Tags cannot be empty strings") 608 if " " in tag: 609 raise ValueError("Tags cannot contain spaces. Try underscores.") 610 611 return self 612 613 614# Define the type alias for clarity 615DatasetFilter = Callable[[TaskRun], bool] 616 617 618def AllDatasetFilter(_: TaskRun) -> bool: 619 return True 620 621 622def HighRatingDatasetFilter(task_run: TaskRun) -> bool: 623 if task_run.output is None: 624 return False 625 if task_run.repaired_output is not None: 626 # Repairs always considered high quality 627 return True 628 if task_run.output.rating is None: 629 return False 630 return task_run.output.rating.is_high_quality() 631 632 633class DatasetSplitDefinition(BaseModel): 634 """ 635 A definition of a split in a dataset. 636 637 Example: name="train", description="The training set", percentage=0.8 (80% of the dataset) 638 """ 639 640 name: str = NAME_FIELD 641 description: str | None = Field( 642 default=None, 643 description="A description of the dataset for you and your team. Not used in training.", 644 ) 645 percentage: float = Field( 646 ge=0.0, 647 le=1.0, 648 description="The percentage of the dataset that this split represents (between 0 and 1).", 649 ) 650 651 652AllSplitDefinition: list[DatasetSplitDefinition] = [ 653 DatasetSplitDefinition(name="all", percentage=1.0) 654] 655Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [ 656 DatasetSplitDefinition(name="train", percentage=0.8), 657 DatasetSplitDefinition(name="test", percentage=0.2), 658] 659Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [ 660 DatasetSplitDefinition(name="train", percentage=0.6), 661 DatasetSplitDefinition(name="test", percentage=0.2), 662 DatasetSplitDefinition(name="val", percentage=0.2), 663] 664Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [ 665 DatasetSplitDefinition(name="train", percentage=0.8), 666 DatasetSplitDefinition(name="test", percentage=0.1), 667 DatasetSplitDefinition(name="val", percentage=0.1), 668] 669 670 671class DatasetSplit(KilnParentedModel): 672 """ 673 A collection of task runs, with optional splits (train, test, validation). 674 675 Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks. 676 677 Maintains a list of IDs for each split, to avoid data duplication. 678 """ 679 680 name: str = NAME_FIELD 681 description: str | None = Field( 682 default=None, 683 description="A description of the dataset for you and your team. 
Not used in training.", 684 ) 685 splits: list[DatasetSplitDefinition] = Field( 686 default_factory=list, 687 description="The splits in the dataset.", 688 ) 689 split_contents: dict[str, list[str]] = Field( 690 description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.", 691 ) 692 693 @model_validator(mode="after") 694 def validate_split_percentages(self) -> "DatasetSplit": 695 total = sum(split.percentage for split in self.splits) 696 if not math.isclose(total, 1.0, rel_tol=1e-9): 697 raise ValueError(f"The sum of split percentages must be 1.0 (got {total})") 698 return self 699 700 @classmethod 701 def from_task( 702 cls, 703 name: str, 704 task: "Task", 705 splits: list[DatasetSplitDefinition], 706 filter: DatasetFilter = AllDatasetFilter, 707 description: str | None = None, 708 ): 709 """ 710 Build a dataset split from a task. 711 """ 712 split_contents = cls.build_split_contents(task, splits, filter) 713 return cls( 714 parent=task, 715 name=name, 716 description=description, 717 splits=splits, 718 split_contents=split_contents, 719 ) 720 721 @classmethod 722 def build_split_contents( 723 cls, 724 task: "Task", 725 splits: list[DatasetSplitDefinition], 726 filter: DatasetFilter, 727 ) -> dict[str, list[str]]: 728 valid_ids = [] 729 for task_run in task.runs(): 730 if filter(task_run): 731 valid_ids.append(task_run.id) 732 733 # Shuffle and split by split percentage 734 random.shuffle(valid_ids) 735 split_contents = {} 736 start_idx = 0 737 remaining_items = len(valid_ids) 738 739 # Handle all splits except the last one 740 for split in splits[:-1]: 741 split_size = round(len(valid_ids) * split.percentage) 742 split_contents[split.name] = valid_ids[start_idx : start_idx + split_size] 743 start_idx += split_size 744 remaining_items -= split_size 745 746 # Last split gets all remaining items (for rounding) 747 if splits: 748 split_contents[splits[-1].name] = valid_ids[start_idx:] 749 750 return split_contents 751 752 def parent_task(self) -> "Task | None": 753 # inline import to avoid circular import 754 from kiln_ai.datamodel import Task 755 756 if not isinstance(self.parent, Task): 757 return None 758 return self.parent 759 760 def missing_count(self) -> int: 761 """ 762 Returns: 763 int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset 764 """ 765 parent = self.parent_task() 766 if parent is None: 767 raise ValueError("DatasetSplit has no parent task") 768 769 runs = parent.runs(readonly=True) 770 all_ids = set(run.id for run in runs) 771 all_ids_in_splits = set() 772 for ids in self.split_contents.values(): 773 all_ids_in_splits.update(ids) 774 missing = all_ids_in_splits - all_ids 775 return len(missing) 776 777 778class Prompt(KilnParentedModel): 779 """ 780 A prompt for a task. 781 """ 782 783 name: str = NAME_FIELD 784 prompt: str = Field( 785 description="The prompt for the task.", 786 min_length=1, 787 ) 788 chain_of_thought_instructions: str | None = Field( 789 default=None, 790 description="Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting. COT will not be used unless this is provided.", 791 ) 792 793 794class TaskRequirement(BaseModel): 795 """ 796 Defines a specific requirement that should be met by task outputs. 
797 798 Includes an identifier, name, description, instruction for meeting the requirement, 799 priority level, and rating type (five_star, pass_fail, pass_fail_critical, custom). 800 """ 801 802 id: ID_TYPE = ID_FIELD 803 name: str = SHORT_NAME_FIELD 804 description: str | None = Field(default=None) 805 instruction: str = Field(min_length=1) 806 priority: Priority = Field(default=Priority.p2) 807 type: TaskOutputRatingType = Field(default=TaskOutputRatingType.five_star) 808 809 810class TaskDeterminism(str, Enum): 811 """ 812 Defines how strictly task outputs should match expected results. 813 814 - deterministic: Requires exact matches 815 - semantic_match: Allows different wording with same meaning 816 - flexible: Allows variation in both wording and meaning within requirements 817 """ 818 819 deterministic = "deterministic" # Expect exact match 820 semantic_match = "semantic_match" # Expect same meaning, but flexible on expression of the meaning 821 flexible = "flexible" # Flexible on semantic output. Eval should be custom based on parsing requirements. 822 823 824class Task( 825 KilnParentedModel, 826 KilnParentModel, 827 parent_of={ 828 "runs": TaskRun, 829 "dataset_splits": DatasetSplit, 830 "finetunes": Finetune, 831 "prompts": Prompt, 832 }, 833): 834 """ 835 Represents a specific task to be performed, with associated requirements and validation rules. 836 837 Contains the task definition, requirements, input/output schemas, and maintains 838 a collection of task runs. 839 """ 840 841 name: str = NAME_FIELD 842 description: str | None = Field( 843 default=None, 844 description="A description of the task for you and your team. Will not be used in prompts/training/validation.", 845 ) 846 instruction: str = Field( 847 min_length=1, 848 description="The instructions for the task. Will be used in prompts/training/validation.", 849 ) 850 requirements: List[TaskRequirement] = Field(default=[]) 851 output_json_schema: JsonObjectSchema | None = None 852 input_json_schema: JsonObjectSchema | None = None 853 thinking_instruction: str | None = Field( 854 default=None, 855 description="Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting.", 856 ) 857 858 def output_schema(self) -> Dict | None: 859 if self.output_json_schema is None: 860 return None 861 return schema_from_json_str(self.output_json_schema) 862 863 def input_schema(self) -> Dict | None: 864 if self.input_json_schema is None: 865 return None 866 return schema_from_json_str(self.input_json_schema) 867 868 # These wrappers help for typechecking. TODO P2: fix this in KilnParentModel 869 def runs(self, readonly: bool = False) -> list[TaskRun]: 870 return super().runs(readonly=readonly) # type: ignore 871 872 def dataset_splits(self, readonly: bool = False) -> list[DatasetSplit]: 873 return super().dataset_splits(readonly=readonly) # type: ignore 874 875 def finetunes(self, readonly: bool = False) -> list[Finetune]: 876 return super().finetunes(readonly=readonly) # type: ignore 877 878 def prompts(self, readonly: bool = False) -> list[Prompt]: 879 return super().prompts(readonly=readonly) # type: ignore 880 881 882class Project(KilnParentModel, parent_of={"tasks": Task}): 883 """ 884 A collection of related tasks. 885 886 Projects organize tasks into logical groups and provide high-level descriptions 887 of the overall goals. 
888 """ 889 890 name: str = NAME_FIELD 891 description: str | None = Field( 892 default=None, 893 description="A description of the project for you and your team. Will not be used in prompts/training/validation.", 894 ) 895 896 # Needed for typechecking. TODO P2: fix this in KilnParentModel 897 def tasks(self) -> list[Task]: 898 return super().tasks() # type: ignore
class Task(KilnParentedModel, KilnParentModel, parent_of={"runs": TaskRun, "dataset_splits": DatasetSplit, "finetunes": Finetune, "prompts": Prompt})
Represents a specific task to be performed, with associated requirements and validation rules.
Contains the task definition, requirements, input/output schemas, and maintains a collection of task runs.
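A sketch of creating a task under a project and freezing a dataset split from it. The names, schemas, and run IDs are hypothetical, and this assumes (as the validators above suggest) that the parent relationship can be set at construction time:

    import json

    from kiln_ai.datamodel import (
        DatasetSplit,
        Project,
        Task,
        Train80Test20SplitDefinition,
    )

    project = Project(name="Reviews")
    task = Task(
        parent=project,
        name="Classify Review",
        instruction="Classify the sentiment of the product review as positive or negative.",
        input_json_schema=json.dumps(
            {
                "type": "object",
                "properties": {"review": {"type": "string"}},
                "required": ["review"],
            }
        ),
        output_json_schema=json.dumps(
            {
                "type": "object",
                "properties": {"sentiment": {"type": "string"}},
                "required": ["sentiment"],
            }
        ),
    )
    # Schemas are stored as JSON strings; the accessors parse them back to dicts.
    assert task.input_schema()["required"] == ["review"]

    # Freeze an 80/20 split. DatasetSplit.from_task(...) builds this from persisted
    # runs; here we pass explicit (hypothetical) run IDs to stay self-contained.
    split = DatasetSplit(
        parent=task,
        name="Frozen V1",
        splits=Train80Test20SplitDefinition,
        split_contents={"train": ["run_001"], "test": ["run_002"]},
    )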
class Project(KilnParentModel, parent_of={"tasks": Task})
A collection of related tasks.
Projects organize tasks into logical groups and provide high-level descriptions of the overall goals.
class TaskRun(KilnParentedModel)
Represents a single execution of a Task.
Contains the input used, its source, the output produced, and optional repair information if the output needed correction.
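A sketch of constructing a run by hand, with hypothetical content and property values. With no parent Task attached yet, schema validation is deferred (see validate_input_format in the source above), so plaintext input and output are fine here:

    from kiln_ai.datamodel import DataSource, DataSourceType, TaskOutput, TaskRun

    run = TaskRun(
        input="Summarize: the quarterly report shows revenue grew 12 percent...",
        input_source=DataSource(
            type=DataSourceType.human,
            properties={"created_by": "jane@example.com"},
        ),
        output=TaskOutput(
            output="Revenue grew 12 percent in the quarter.",
            source=DataSource(
                type=DataSourceType.synthetic,
                properties={
                    "model_name": "gpt-4o",          # hypothetical values; any
                    "model_provider": "openai",      # non-empty strings satisfy
                    "adapter_name": "demo_adapter",  # the DataSource validators
                },
            ),
        ),
        tags=["demo", "summaries_v1"],  # a space in a tag would raise a ValueError
    )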
516 @model_validator(mode="after") 517 def validate_input_format(self, info: ValidationInfo) -> Self: 518 # Don't validate if loading from file (not new). Too slow. 519 # We don't allow changing task schema, so this is redundant validation. 520 # Note: we still validate if editing a loaded model 521 if self.loading_from_file(info): 522 # Consider loading an existing model as validated. 523 self._last_validated_input = self.input 524 return self 525 526 # Don't validate if input has not changed. Too slow to run this every time. 527 if ( 528 hasattr(self, "_last_validated_input") 529 and self.input == self._last_validated_input 530 ): 531 return self 532 533 task = self.parent_task() 534 if task is None: 535 # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving) 536 return self 537 538 # validate output 539 if task.input_json_schema is not None: 540 try: 541 validate_schema(json.loads(self.input), task.input_json_schema) 542 except json.JSONDecodeError: 543 raise ValueError("Input is not a valid JSON object") 544 except jsonschema.exceptions.ValidationError as e: 545 raise ValueError(f"Input does not match task input schema: {e}") 546 self._last_validated_input = self.input 547 return self
549 @model_validator(mode="after") 550 def validate_output_format(self, info: ValidationInfo) -> Self: 551 # Don't validate if loading from file (not new). Too slow. 552 # Note: we still validate if editing a loaded model's output. 553 if self.loading_from_file(info): 554 # Consider loading an existing model as validated. 555 self._last_validated_output = self.output.output if self.output else None 556 return self 557 558 # Don't validate unless output has changed since last validation. 559 # The validator is slow and costly, don't want it running when setting other fields. 560 if ( 561 hasattr(self, "_last_validated_output") 562 and self.output is not None 563 and self.output.output == self._last_validated_output 564 ): 565 return self 566 567 task = self.parent_task() 568 if task is None: 569 return self 570 571 self.output.validate_output_format(task) 572 self._last_validated_output = self.output.output if self.output else None 573 return self
575 @model_validator(mode="after") 576 def validate_repaired_output(self) -> Self: 577 if self.repaired_output is not None: 578 if self.repaired_output.rating is not None: 579 raise ValueError( 580 "Repaired output rating must be None. Repaired outputs are assumed to have a perfect rating, as they have been fixed." 581 ) 582 if self.repair_instructions is None and self.repaired_output is not None: 583 raise ValueError( 584 "Repair instructions are required if providing a repaired output." 585 ) 586 if self.repair_instructions is not None and self.repaired_output is None: 587 raise ValueError( 588 "A repaired output is required if providing repair instructions." 589 ) 590 return self
592 @model_validator(mode="after") 593 def validate_input_source(self, info: ValidationInfo) -> Self: 594 # On strict mode and not loaded from file, we validate input_source is not None. 595 # We want to be able to load any data, even if it's not perfect. But we want to create perfect data when adding new data. 596 if not strict_mode(): 597 return self 598 if self.loaded_from_file(info): 599 return self 600 if self.input_source is None: 601 raise ValueError("input_source is required when strict mode is enabled") 602 return self
Configuration for the model, should be a dictionary conforming to [ConfigDict
][pydantic.config.ConfigDict].
122 def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None: 123 """We need to both initialize private attributes and call the user-defined model_post_init 124 method. 125 """ 126 init_private_attributes(self, context) 127 original_model_post_init(self, context)
We need to both initialize private attributes and call the user-defined model_post_init method.
Inherited Members
class TaskOutput(KilnBaseModel)
An output for a specific task run.
Contains the actual output content, its source (human or synthetic), and optional rating information.
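Unlike the pydantic validators, validate_output_format is an explicit method called with the owning task. A sketch with a hypothetical task and schema (validate_output_format raises a plain ValueError, so it can be caught directly):

    import json

    from kiln_ai.datamodel import Task, TaskOutput

    task = Task(
        name="Extract Price",
        instruction="Extract the price from the text.",
        output_json_schema=json.dumps(
            {
                "type": "object",
                "properties": {"price": {"type": "number"}},
                "required": ["price"],
            }
        ),
    )

    TaskOutput(output='{"price": 12.5}').validate_output_format(task)  # passes

    try:
        TaskOutput(output='{"cost": 12.5}').validate_output_format(task)
    except ValueError as e:
        print(e)  # Output does not match task output schema: ...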
class TaskOutputRating(KilnBaseModel)
A rating for a task output, including an overall rating and ratings for each requirement.
Supports:
- five_star: 1-5 star ratings
- pass_fail: boolean pass/fail (1.0 = pass, 0.0 = fail)
- pass_fail_critical: tri-state (1.0 = pass, 0.0 = fail, -1.0 = critical fail)
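A sketch of the rating types above, including the legacy-format upgrade. "req_1" is a hypothetical requirement ID, and this assumes KilnBaseModel's bookkeeping fields all have defaults:

    from kiln_ai.datamodel import (
        RequirementRating,
        TaskOutputRating,
        TaskOutputRatingType,
    )

    rating = TaskOutputRating(
        value=5.0,  # five_star values must be whole floats in 1-5
        requirement_ratings={
            "req_1": RequirementRating(value=1.0, type=TaskOutputRatingType.pass_fail),
        },
    )
    assert rating.type == TaskOutputRatingType.five_star  # the default
    assert rating.is_high_quality()  # five_star ratings >= 4 count as high quality

    # The mode="before" validator still accepts the legacy format (bare floats
    # keyed by requirement ID) and upgrades it to RequirementRating objects:
    legacy = TaskOutputRating.model_validate(
        {"value": 4.0, "requirement_ratings": {"req_1": 3.0}}
    )
    assert legacy.requirement_ratings["req_1"].value == 3.0
    assert legacy.requirement_ratings["req_1"].type == TaskOutputRatingType.five_star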
131 @model_validator(mode="before") 132 def upgrade_old_format(cls, data: dict) -> dict: 133 if not isinstance(data, dict): 134 return data 135 136 # Check if we have the old format (dict of floats) 137 req_ratings = data.get("requirement_ratings", {}) 138 if req_ratings and all( 139 isinstance(v, (int, float)) for v in req_ratings.values() 140 ): 141 # Convert each float to a RequirementRating object 142 # all ratings are five star at the point we used this format 143 data["requirement_ratings"] = { 144 k: {"value": v, "type": TaskOutputRatingType.five_star} 145 for k, v in req_ratings.items() 146 } 147 148 return data
151 def is_high_quality(self) -> bool: 152 if self.value is None: 153 return False 154 155 if self.type == TaskOutputRatingType.five_star: 156 return self.value >= 4 157 elif self.type == TaskOutputRatingType.pass_fail: 158 return self.value == 1.0 159 elif self.type == TaskOutputRatingType.pass_fail_critical: 160 return self.value == 1.0 161 return False
163 @model_validator(mode="after") 164 def validate_rating(self) -> Self: 165 if self.type not in TaskOutputRatingType: 166 raise ValueError(f"Invalid rating type: {self.type}") 167 168 # Overall rating is optional 169 if self.value is not None: 170 self._validate_rating(self.type, self.value, "overall rating") 171 172 for req_id, req_rating in self.requirement_ratings.items(): 173 self._validate_rating( 174 req_rating.type, 175 req_rating.value, 176 f"requirement rating for req ID: {req_id}", 177 ) 178 179 return self
Configuration for the model, should be a dictionary conforming to [ConfigDict
][pydantic.config.ConfigDict].
122 def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None: 123 """We need to both initialize private attributes and call the user-defined model_post_init 124 method. 125 """ 126 init_private_attributes(self, context) 127 original_model_post_init(self, context)
We need to both initialize private attributes and call the user-defined model_post_init method.
class Priority(IntEnum)
Defines priority levels for tasks and requirements, where P0 is highest priority.
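Because Priority is an IntEnum, priorities compare and sort numerically (lower value means higher priority). A small sketch with hypothetical requirements:

    from kiln_ai.datamodel import Priority, TaskRequirement

    reqs = [
        TaskRequirement(name="Tone", instruction="Use a friendly tone.", priority=Priority.p2),
        TaskRequirement(name="Accuracy", instruction="No factual errors.", priority=Priority.p0),
    ]
    reqs.sort(key=lambda r: r.priority)  # p0 sorts first
    assert reqs[0].name == "Accuracy"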
class DataSource(BaseModel)
Represents the origin of data, either human or synthetic, with associated properties.
Properties vary based on the source type - for synthetic sources this includes model information, for human sources this includes creator information.
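Concretely, a human source must carry created_by, while a synthetic source must carry model_name, model_provider, and adapter_name. A minimal sketch (property values are hypothetical):

from kiln_ai.datamodel import DataSource, DataSourceType

human_source = DataSource(
    type=DataSourceType.human,
    properties={"created_by": "jane_doe"},  # required for human sources
)

synthetic_source = DataSource(
    type=DataSourceType.synthetic,
    properties={
        "model_name": "gpt-4o",          # required for synthetic sources
        "model_provider": "openai",      # required for synthetic sources
        "adapter_name": "some_adapter",  # required for synthetic sources
    },
)

# Mixing them fails validation: e.g. created_by is not allowed on a synthetic source.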
    @model_validator(mode="after")
    def validate_properties(self) -> "DataSource":
        for prop in self._data_source_properties:
            # Check the property type is correct
            if prop.name in self.properties:
                if not isinstance(self.properties[prop.name], prop.type):
                    raise ValueError(
                        f"'{prop.name}' must be of type {prop.type.__name__} for {self.type} data source"
                    )
            # Check the property is required for the data source type
            if self.type in prop.required_for:
                if prop.name not in self.properties:
                    raise ValueError(
                        f"'{prop.name}' is required for {self.type} data source"
                    )
            # Check the property is not allowed for the data source type
            elif self.type in prop.not_allowed_for and prop.name in self.properties:
                raise ValueError(
                    f"'{prop.name}' is not allowed for {self.type} data source"
                )
        return self
    @model_validator(mode="after")
    def validate_no_empty_properties(self) -> Self:
        for prop, value in self.properties.items():
            if isinstance(value, str) and value == "":
                raise ValueError(
                    f"Property '{prop}' must be a non-empty string for {self.type} data source"
                )
        return self
class DataSourceType(str, Enum):
    """
    The source type of a piece of data.

    Human: a human created the data
    Synthetic: a model created the data
    """

    human = "human"
    synthetic = "synthetic"
The source type of a piece of data.
Human: a human created the data
Synthetic: a model created the data
class DataSourceProperty(BaseModel):
    """
    Defines a property that can be associated with a data source.

    Includes validation rules for when properties are required or not allowed
    based on the data source type.
    """

    name: str
    type: Type[Union[str, int, float]]
    required_for: List[DataSourceType] = []
    not_allowed_for: List[DataSourceType] = []
Defines a property that can be associated with a data source.
Includes validation rules for when properties are required or not allowed based on the data source type.
class Finetune(KilnParentedModel):
    """
    The Kiln fine-tune datamodel.

    Initially holds a reference to a training job, with needed identifiers to update the status. When complete, contains the new model ID.
    """

    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the fine-tune for you and your team. Not used in training.",
    )
    structured_output_mode: StructuredOutputMode | None = Field(
        default=None,
        description="The mode to use to train the model for structured output, if it was trained with structured output. Will determine how we call the tuned model, so we call with the matching mode.",
    )
    provider: str = Field(
        description="The provider to use for the fine-tune (e.g. 'openai')."
    )
    base_model_id: str = Field(
        description="The id of the base model to use for the fine-tune. This string relates to the provider's IDs for their own models, not Kiln IDs."
    )
    provider_id: str | None = Field(
        default=None,
        description="The ID of the fine-tune job on the provider's side. May not be the same as the fine_tune_model_id.",
    )
    fine_tune_model_id: str | None = Field(
        default=None,
        description="The ID of the fine-tuned model on the provider's side. May not be the same as the provider_id.",
    )
    dataset_split_id: str = Field(
        description="The ID of the dataset split to use for this fine-tune.",
    )
    train_split_name: str = Field(
        default="train",
        description="The name of the training split to use for this fine-tune.",
    )
    validation_split_name: str | None = Field(
        default=None,
        description="The name of the validation split to use for this fine-tune. Optional.",
    )
    parameters: dict[str, str | int | float | bool] = Field(
        default={},
        description="The parameters to use for this fine-tune. These are provider-specific.",
    )
    system_message: str = Field(
        description="The system message to use for this fine-tune.",
    )
    latest_status: FineTuneStatusType = Field(
        default=FineTuneStatusType.unknown,
        description="The latest known status of this fine-tune. Not updated in real time.",
    )
    properties: Dict[str, str | int | float] = Field(
        default={},
        description="Properties of the fine-tune. Different providers may use different properties.",
    )

    def parent_task(self) -> Task | None:
        if not isinstance(self.parent, Task):
            return None
        return self.parent
The Kiln fine-tune datamodel.
Initially holds a reference to a training job, with needed identifiers to update the status. When complete, contains the new model ID.
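A minimal sketch of creating a Finetune record (the name and IDs below are hypothetical; in practice the record is parented to a Task and its latest_status is updated as the provider's job progresses):

from kiln_ai.datamodel import Finetune

finetune = Finetune(
    name="summarizer-v1",                      # hypothetical name
    provider="openai",
    base_model_id="gpt-4o-mini-2024-07-18",    # the provider's model ID, not a Kiln ID
    dataset_split_id="dataset-split-id",       # hypothetical dataset split ID
    system_message="You are a helpful assistant.",
)
# latest_status defaults to FineTuneStatusType.unknown until the job status is fetched.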
class FineTuneStatusType(str, Enum):
    """
    The status type of a fine-tune (running, completed, failed, etc).
    """

    unknown = "unknown"  # server error
    pending = "pending"
    running = "running"
    completed = "completed"
    failed = "failed"
The status type of a fine-tune (running, completed, failed, etc).
class TaskOutputRatingType(str, Enum):
    """Defines the types of rating systems available for task outputs."""

    five_star = "five_star"
    pass_fail = "pass_fail"
    pass_fail_critical = "pass_fail_critical"
    custom = "custom"
Defines the types of rating systems available for task outputs.
class TaskRequirement(BaseModel):
    """
    Defines a specific requirement that should be met by task outputs.

    Includes an identifier, name, description, instruction for meeting the requirement,
    priority level, and rating type (five_star, pass_fail, pass_fail_critical, custom).
    """

    id: ID_TYPE = ID_FIELD
    name: str = SHORT_NAME_FIELD
    description: str | None = Field(default=None)
    instruction: str = Field(min_length=1)
    priority: Priority = Field(default=Priority.p2)
    type: TaskOutputRatingType = Field(default=TaskOutputRatingType.five_star)
Defines a specific requirement that should be met by task outputs.
Includes an identifier, name, description, instruction for meeting the requirement, priority level, and rating type (five_star, pass_fail, pass_fail_critical, custom).
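A minimal sketch of a requirement using a pass/fail rating (names and text are illustrative):

from kiln_ai.datamodel import Priority, TaskOutputRatingType, TaskRequirement

requirement = TaskRequirement(
    name="Concise",
    instruction="Answer in three sentences or fewer.",
    priority=Priority.p1,
    type=TaskOutputRatingType.pass_fail,
)
# id is generated by the ID_FIELD default; description is optional.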
class TaskDeterminism(str, Enum):
    """
    Defines how strictly task outputs should match expected results.

    - deterministic: Requires exact matches
    - semantic_match: Allows different wording with same meaning
    - flexible: Allows variation in both wording and meaning within requirements
    """

    deterministic = "deterministic"  # Expect exact match
    semantic_match = "semantic_match"  # Expect same meaning, but flexible on expression of the meaning
    flexible = "flexible"  # Flexible on semantic output. Eval should be custom based on parsing requirements.
Defines how strictly task outputs should match expected results.
- deterministic: Requires exact matches
- semantic_match: Allows different wording with same meaning
- flexible: Allows variation in both wording and meaning within requirements
class DatasetSplitDefinition(BaseModel):
    """
    A definition of a split in a dataset.

    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
    """

    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    percentage: float = Field(
        ge=0.0,
        le=1.0,
        description="The percentage of the dataset that this split represents (between 0 and 1).",
    )
A definition of a split in a dataset.
Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
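Expanding that example into code (the split names are conventional, not enforced):

from kiln_ai.datamodel import DatasetSplitDefinition

splits = [
    DatasetSplitDefinition(name="train", description="The training set", percentage=0.8),
    DatasetSplitDefinition(name="test", percentage=0.1),
    DatasetSplitDefinition(name="val", percentage=0.1),
]
# DatasetSplit validates that percentages across all definitions sum to 1.0.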
class DatasetSplit(KilnParentedModel):
    """
    A collection of task runs, with optional splits (train, test, validation).

    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

    Maintains a list of IDs for each split, to avoid data duplication.
    """

    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the dataset for you and your team. Not used in training.",
    )
    splits: list[DatasetSplitDefinition] = Field(
        default_factory=list,
        description="The splits in the dataset.",
    )
    split_contents: dict[str, list[str]] = Field(
        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
    )

    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self

    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter = AllDatasetFilter,
        description: str | None = None,
    ):
        """
        Build a dataset split from a task.
        """
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
        )

    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        valid_ids = []
        for task_run in task.runs():
            if filter(task_run):
                valid_ids.append(task_run.id)

        # Shuffle and split by split percentage
        random.shuffle(valid_ids)
        split_contents = {}
        start_idx = 0
        remaining_items = len(valid_ids)

        # Handle all splits except the last one
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size
            remaining_items -= split_size

        # Last split gets all remaining items (for rounding)
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents

    def parent_task(self) -> "Task | None":
        # inline import to avoid circular import
        from kiln_ai.datamodel import Task

        if not isinstance(self.parent, Task):
            return None
        return self.parent

    def missing_count(self) -> int:
        """
        Returns:
            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
        """
        parent = self.parent_task()
        if parent is None:
            raise ValueError("DatasetSplit has no parent task")

        runs = parent.runs(readonly=True)
        all_ids = set(run.id for run in runs)
        all_ids_in_splits = set()
        for ids in self.split_contents.values():
            all_ids_in_splits.update(ids)
        missing = all_ids_in_splits - all_ids
        return len(missing)
A collection of task runs, with optional splits (train, test, validation).
Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
Maintains a list of IDs for each split, to avoid data duplication.
    @model_validator(mode="after")
    def validate_split_percentages(self) -> "DatasetSplit":
        total = sum(split.percentage for split in self.splits)
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
        return self
    @classmethod
    def from_task(
        cls,
        name: str,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter = AllDatasetFilter,
        description: str | None = None,
    ):
        """
        Build a dataset split from a task.
        """
        split_contents = cls.build_split_contents(task, splits, filter)
        return cls(
            parent=task,
            name=name,
            description=description,
            splits=splits,
            split_contents=split_contents,
        )
Build a dataset split from a task.
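A minimal usage sketch, assuming an existing Task instance named task and the splits list defined above (AllDatasetFilter is the default filter, so every run is included):

from kiln_ai.datamodel import DatasetSplit

dataset = DatasetSplit.from_task(
    name="frozen-v1",  # hypothetical name
    task=task,
    splits=splits,
)
# dataset.split_contents now maps each split name to a list of task run IDs.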
    @classmethod
    def build_split_contents(
        cls,
        task: "Task",
        splits: list[DatasetSplitDefinition],
        filter: DatasetFilter,
    ) -> dict[str, list[str]]:
        valid_ids = []
        for task_run in task.runs():
            if filter(task_run):
                valid_ids.append(task_run.id)

        # Shuffle and split by split percentage
        random.shuffle(valid_ids)
        split_contents = {}
        start_idx = 0
        remaining_items = len(valid_ids)

        # Handle all splits except the last one
        for split in splits[:-1]:
            split_size = round(len(valid_ids) * split.percentage)
            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
            start_idx += split_size
            remaining_items -= split_size

        # Last split gets all remaining items (for rounding)
        if splits:
            split_contents[splits[-1].name] = valid_ids[start_idx:]

        return split_contents
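Note the rounding behavior: with 11 eligible run IDs and splits of 0.8/0.1/0.1, the first split receives round(11 × 0.8) = 9 IDs, the second receives round(11 × 0.1) = 1, and the last split absorbs the remaining 1, so every ID lands in exactly one split even when the percentages don't divide evenly.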
    def missing_count(self) -> int:
        """
        Returns:
            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
        """
        parent = self.parent_task()
        if parent is None:
            raise ValueError("DatasetSplit has no parent task")

        runs = parent.runs(readonly=True)
        all_ids = set(run.id for run in runs)
        all_ids_in_splits = set()
        for ids in self.split_contents.values():
            all_ids_in_splits.update(ids)
        missing = all_ids_in_splits - all_ids
        return len(missing)
Returns:
    int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
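In other words, a nonzero missing_count() typically means runs were deleted from the parent task after the split was frozen: the split still references their IDs, but the underlying runs are gone.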
class RequirementRating(BaseModel):
    """Rating for a specific requirement within a task output."""

    value: float = Field(
        description="The rating value. Interpretation depends on rating type"
    )
    type: TaskOutputRatingType = Field(description="The type of rating")
Rating for a specific requirement within a task output.
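A minimal sketch showing how requirement ratings attach to a TaskOutputRating, keyed by requirement ID (the ID below is hypothetical):

from kiln_ai.datamodel import RequirementRating, TaskOutputRating, TaskOutputRatingType

rating = TaskOutputRating(
    type=TaskOutputRatingType.five_star,
    value=5.0,  # five_star values must be whole numbers from 1 to 5
    requirement_ratings={
        "req_123": RequirementRating(  # hypothetical requirement ID
            value=4.0,
            type=TaskOutputRatingType.five_star,
        ),
    },
)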
class Prompt(KilnParentedModel):
    """
    A prompt for a task.
    """

    name: str = NAME_FIELD
    prompt: str = Field(
        description="The prompt for the task.",
        min_length=1,
    )
    chain_of_thought_instructions: str | None = Field(
        default=None,
        description="Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting. COT will not be used unless this is provided.",
    )
A prompt for a task.
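A minimal sketch (the text is illustrative); note that chain-of-thought prompting is only used when chain_of_thought_instructions is provided:

from kiln_ai.datamodel import Prompt

prompt = Prompt(
    name="baseline",
    prompt="Summarize the provided article in one paragraph.",
    chain_of_thought_instructions="Think step by step before writing the summary.",
)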