kiln_ai.adapters.fine_tune.dataset_formatter
import json
import tempfile
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Protocol

from kiln_ai.datamodel import DatasetSplit, TaskRun


class DatasetFormat(str, Enum):
    """Format types for dataset generation"""

    CHAT_MESSAGE_RESPONSE_JSONL = "chat_message_response_jsonl"
    CHAT_MESSAGE_TOOLCALL_JSONL = "chat_message_toolcall_jsonl"


class FormatGenerator(Protocol):
    """Protocol for format generators"""

    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...


def generate_chat_message_response(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with plaintext response"""
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": task_run.output.output},
        ]
    }


def generate_chat_message_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with tool call response"""
    try:
        arguments = json.loads(task_run.output.output)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in task run output: {e}") from e

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "arguments": json.dumps(arguments),
                        },
                    }
                ],
            },
        ]
    }


FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
    DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL: generate_chat_message_response,
    DatasetFormat.CHAT_MESSAGE_TOOLCALL_JSONL: generate_chat_message_toolcall,
}


class DatasetFormatter:
    """Handles formatting of datasets into various output formats"""

    def __init__(self, dataset: DatasetSplit, system_message: str):
        self.dataset = dataset
        self.system_message = system_message

        task = dataset.parent_task()
        if task is None:
            raise ValueError("Dataset has no parent task")
        self.task = task

    def dump_to_file(
        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
    ) -> Path:
        """
        Format the dataset into the specified format.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to temp directory

        Returns:
            Path to the generated file
        """
        if format_type not in FORMAT_GENERATORS:
            raise ValueError(f"Unsupported format: {format_type}")
        if split_name not in self.dataset.split_contents:
            raise ValueError(f"Split {split_name} not found in dataset")

        generator = FORMAT_GENERATORS[format_type]

        # Write to a temp file if no path is provided
        output_path = path or (
            Path(tempfile.gettempdir())
            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
        )

        runs = self.task.runs()
        runs_by_id = {run.id: run for run in runs}

        # Generate formatted output with UTF-8 encoding
        with open(output_path, "w", encoding="utf-8") as f:
            for run_id in self.dataset.split_contents[split_name]:
                # Use .get() so a missing run ID raises the descriptive error
                # below rather than a bare KeyError
                task_run = runs_by_id.get(run_id)
                if task_run is None:
                    raise ValueError(
                        f"Task run {run_id} not found. This is required by this dataset."
                    )

                example = generator(task_run, self.system_message)
                f.write(json.dumps(example) + "\n")

        return output_path
class
DatasetFormat(builtins.str, enum.Enum):
class DatasetFormat(str, Enum):
    """Format types for dataset generation"""

    CHAT_MESSAGE_RESPONSE_JSONL = "chat_message_response_jsonl"
    CHAT_MESSAGE_TOOLCALL_JSONL = "chat_message_toolcall_jsonl"
Format types for dataset generation
CHAT_MESSAGE_RESPONSE_JSONL = <DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL: 'chat_message_response_jsonl'>
CHAT_MESSAGE_TOOLCALL_JSONL = <DatasetFormat.CHAT_MESSAGE_TOOLCALL_JSONL: 'chat_message_toolcall_jsonl'>
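Because the enum mixes in str, a member can be looked up from its string value and compared directly against plain strings; a minimal sketch:

from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat

fmt = DatasetFormat("chat_message_response_jsonl")
assert fmt is DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL
assert fmt == "chat_message_response_jsonl"  # holds because of the str mixin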
class
FormatGenerator(typing.Protocol):
class FormatGenerator(Protocol):
    """Protocol for format generators"""

    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...
Protocol for format generators
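Any callable with this signature satisfies the protocol, so custom formats can sit alongside the built-in generators. A hedged sketch, where the ShareGPT-style field names are illustrative and not part of kiln_ai:

from typing import Any, Dict

from kiln_ai.adapters.fine_tune.dataset_formatter import FormatGenerator
from kiln_ai.datamodel import TaskRun


def generate_sharegpt_style(task_run: TaskRun, system_message: str) -> Dict[str, Any]:
    # Hypothetical generator: shown only to illustrate that any callable
    # matching FormatGenerator's signature can be used.
    return {
        "conversations": [
            {"from": "system", "value": system_message},
            {"from": "human", "value": task_run.input},
            {"from": "gpt", "value": task_run.output.output},
        ]
    }


# Type checkers accept this assignment because the signature matches the protocol.
custom_generator: FormatGenerator = generate_sharegpt_style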
def
generate_chat_message_response( task_run: kiln_ai.datamodel.TaskRun, system_message: str) -> Dict[str, Any]:
def generate_chat_message_response(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with plaintext response"""
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": task_run.output.output},
        ]
    }
Generate OpenAI chat format with plaintext response
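Each call returns one chat-format dict; serialized with json.dumps it becomes a single JSONL line. A sketch using a stand-in object that carries only the fields the generator reads (input and output.output), since constructing a real TaskRun is out of scope here:

import json
from types import SimpleNamespace

from kiln_ai.adapters.fine_tune.dataset_formatter import generate_chat_message_response

# Stand-in with just the attributes the generator touches; not a real TaskRun.
fake_run = SimpleNamespace(
    input="Summarize: The sky is blue.",
    output=SimpleNamespace(output="The sky is blue."),
)

example = generate_chat_message_response(fake_run, "You are a summarizer.")
print(json.dumps(example))  # one system, one user, and one assistant message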
def
generate_chat_message_toolcall( task_run: kiln_ai.datamodel.TaskRun, system_message: str) -> Dict[str, Any]:
def generate_chat_message_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with tool call response"""
    try:
        arguments = json.loads(task_run.output.output)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in task run output: {e}") from e

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "arguments": json.dumps(arguments),
                        },
                    }
                ],
            },
        ]
    }
Generate OpenAI chat format with tool call response
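This variant expects the run output to already be a JSON string: it is parsed (raising ValueError on invalid JSON) and re-serialized as the arguments of a single task_response function call. A sketch with the same kind of stand-in run as above:

import json
from types import SimpleNamespace

from kiln_ai.adapters.fine_tune.dataset_formatter import generate_chat_message_toolcall

structured_run = SimpleNamespace(
    input="Extract the city from: I live in Paris.",
    output=SimpleNamespace(output=json.dumps({"city": "Paris"})),
)

example = generate_chat_message_toolcall(structured_run, "Extract fields as JSON.")
tool_call = example["messages"][2]["tool_calls"][0]
assert tool_call["function"]["name"] == "task_response"
assert json.loads(tool_call["function"]["arguments"]) == {"city": "Paris"}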
FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] =
{<DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL: 'chat_message_response_jsonl'>: <function generate_chat_message_response>,
 <DatasetFormat.CHAT_MESSAGE_TOOLCALL_JSONL: 'chat_message_toolcall_jsonl'>: <function generate_chat_message_toolcall>}
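The registry maps each DatasetFormat to its generator; dump_to_file rejects unsupported formats before performing this lookup. For example:

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    FORMAT_GENERATORS,
    DatasetFormat,
    generate_chat_message_toolcall,
)

generator = FORMAT_GENERATORS[DatasetFormat.CHAT_MESSAGE_TOOLCALL_JSONL]
assert generator is generate_chat_message_toolcall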
class
DatasetFormatter:
class DatasetFormatter:
    """Handles formatting of datasets into various output formats"""

    def __init__(self, dataset: DatasetSplit, system_message: str):
        self.dataset = dataset
        self.system_message = system_message

        task = dataset.parent_task()
        if task is None:
            raise ValueError("Dataset has no parent task")
        self.task = task

    def dump_to_file(
        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
    ) -> Path:
        """
        Format the dataset into the specified format.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to temp directory

        Returns:
            Path to the generated file
        """
        if format_type not in FORMAT_GENERATORS:
            raise ValueError(f"Unsupported format: {format_type}")
        if split_name not in self.dataset.split_contents:
            raise ValueError(f"Split {split_name} not found in dataset")

        generator = FORMAT_GENERATORS[format_type]

        # Write to a temp file if no path is provided
        output_path = path or (
            Path(tempfile.gettempdir())
            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
        )

        runs = self.task.runs()
        runs_by_id = {run.id: run for run in runs}

        # Generate formatted output with UTF-8 encoding
        with open(output_path, "w", encoding="utf-8") as f:
            for run_id in self.dataset.split_contents[split_name]:
                # Use .get() so a missing run ID raises the descriptive error
                # below rather than a bare KeyError
                task_run = runs_by_id.get(run_id)
                if task_run is None:
                    raise ValueError(
                        f"Task run {run_id} not found. This is required by this dataset."
                    )

                example = generator(task_run, self.system_message)
                f.write(json.dumps(example) + "\n")

        return output_path
Handles formatting of datasets into various output formats
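Typical usage, assuming a DatasetSplit that is already attached to a parent task (e.g., loaded from an existing Kiln project); the split name below is illustrative:

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    DatasetFormat,
    DatasetFormatter,
)

# `dataset` is assumed to be a DatasetSplit with a parent task; the constructor
# raises ValueError otherwise.
formatter = DatasetFormatter(dataset, system_message="You are a helpful assistant.")
train_file = formatter.dump_to_file("train", DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL)
print(train_file)  # file under the system temp directory when no explicit path is given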
def
dump_to_file( self, split_name: str, format_type: DatasetFormat, path: pathlib.Path | None = None) -> pathlib.Path:
    def dump_to_file(
        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
    ) -> Path:
        """
        Format the dataset into the specified format.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to temp directory

        Returns:
            Path to the generated file
        """
        if format_type not in FORMAT_GENERATORS:
            raise ValueError(f"Unsupported format: {format_type}")
        if split_name not in self.dataset.split_contents:
            raise ValueError(f"Split {split_name} not found in dataset")

        generator = FORMAT_GENERATORS[format_type]

        # Write to a temp file if no path is provided
        output_path = path or (
            Path(tempfile.gettempdir())
            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
        )

        runs = self.task.runs()
        runs_by_id = {run.id: run for run in runs}

        # Generate formatted output with UTF-8 encoding
        with open(output_path, "w", encoding="utf-8") as f:
            for run_id in self.dataset.split_contents[split_name]:
                # Use .get() so a missing run ID raises the descriptive error
                # below rather than a bare KeyError
                task_run = runs_by_id.get(run_id)
                if task_run is None:
                    raise ValueError(
                        f"Task run {run_id} not found. This is required by this dataset."
                    )

                example = generator(task_run, self.system_message)
                f.write(json.dumps(example) + "\n")

        return output_path
Format the dataset into the specified format.
Args:
    split_name: Name of the split to dump
    format_type: Format to generate the dataset in
    path: Optional path to write to. If None, writes to the temp directory

Returns:
    Path to the generated file
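A sketch of dumping a named split to an explicit path and checking that each line is a standalone JSON chat example, continuing the formatter from the class-level example above (split name and path are illustrative):

import json
from pathlib import Path

out_path = formatter.dump_to_file(
    "train",
    DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL,
    path=Path("train_chat.jsonl"),
)

with open(out_path, encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)  # one JSON object per line (JSONL)
        assert "messages" in record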