kiln_ai.adapters.fine_tune.dataset_formatter

import json
import tempfile
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Protocol

from kiln_ai.datamodel import DatasetSplit, TaskRun


class DatasetFormat(str, Enum):
    """Format types for dataset generation"""

    CHAT_MESSAGE_RESPONSE_JSONL = "chat_message_response_jsonl"
    CHAT_MESSAGE_TOOLCALL_JSONL = "chat_message_toolcall_jsonl"


class FormatGenerator(Protocol):
    """Protocol for format generators"""

    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...


def generate_chat_message_response(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with plaintext response"""
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": task_run.output.output},
        ]
    }


def generate_chat_message_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with tool call response"""
    try:
        arguments = json.loads(task_run.output.output)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in task run output: {e}") from e

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "arguments": json.dumps(arguments),
                        },
                    }
                ],
            },
        ]
    }


FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
    DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL: generate_chat_message_response,
    DatasetFormat.CHAT_MESSAGE_TOOLCALL_JSONL: generate_chat_message_toolcall,
}


class DatasetFormatter:
    """Handles formatting of datasets into various output formats"""

    def __init__(self, dataset: DatasetSplit, system_message: str):
        self.dataset = dataset
        self.system_message = system_message

        task = dataset.parent_task()
        if task is None:
            raise ValueError("Dataset has no parent task")
        self.task = task

    def dump_to_file(
        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
    ) -> Path:
        """
        Format the dataset into the specified format.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to temp directory

        Returns:
            Path to the generated file
        """
        if format_type not in FORMAT_GENERATORS:
            raise ValueError(f"Unsupported format: {format_type}")
        if split_name not in self.dataset.split_contents:
            raise ValueError(f"Split {split_name} not found in dataset")

        generator = FORMAT_GENERATORS[format_type]

        # Write to a temp file if no path is provided
        output_path = (
            path
            or Path(tempfile.gettempdir())
            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
        )

        runs = self.task.runs()
        runs_by_id = {run.id: run for run in runs}

        # Generate formatted output with UTF-8 encoding
        with open(output_path, "w", encoding="utf-8") as f:
            for run_id in self.dataset.split_contents[split_name]:
                # Use .get() so a missing run surfaces as the descriptive error
                # below rather than a bare KeyError from direct indexing.
                task_run = runs_by_id.get(run_id)
                if task_run is None:
                    raise ValueError(
                        f"Task run {run_id} not found. This is required by this dataset."
                    )

                example = generator(task_run, self.system_message)
                f.write(json.dumps(example) + "\n")

        return output_path

class DatasetFormat(builtins.str, enum.Enum):

Format types for dataset generation

CHAT_MESSAGE_RESPONSE_JSONL = <DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL: 'chat_message_response_jsonl'>
CHAT_MESSAGE_TOOLCALL_JSONL = <DatasetFormat.CHAT_MESSAGE_TOOLCALL_JSONL: 'chat_message_toolcall_jsonl'>
class FormatGenerator(typing.Protocol):

Protocol for format generators
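
Any callable with this call signature satisfies the protocol, so additional formats can be supplied without subclassing anything. A minimal sketch of a custom generator; the function below and its upper-casing behavior are illustrative, not part of the library:

from typing import Any, Dict

from kiln_ai.adapters.fine_tune.dataset_formatter import FormatGenerator
from kiln_ai.datamodel import TaskRun


def generate_uppercase_response(task_run: TaskRun, system_message: str) -> Dict[str, Any]:
    # Hypothetical generator: same chat layout as the built-in plaintext format,
    # but upper-cases the assistant output.
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": task_run.output.output.upper()},
        ]
    }


# Structural typing: the plain function already matches the protocol.
custom_generator: FormatGenerator = generate_uppercase_response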

def generate_chat_message_response(task_run: kiln_ai.datamodel.TaskRun, system_message: str) -> Dict[str, Any]:

Generate OpenAI chat format with plaintext response
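
For illustration, a record built from a hypothetical task run (the strings below are invented) is the dict shown here; DatasetFormatter later writes it as a single JSON line:

example = {
    "messages": [
        {"role": "system", "content": "You are a helpful summarizer."},
        {"role": "user", "content": "Summarize: Kiln exports task runs for fine-tuning."},
        {"role": "assistant", "content": "Kiln turns task runs into chat-format training examples."},
    ]
}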

def generate_chat_message_toolcall(task_run: kiln_ai.datamodel.TaskRun, system_message: str) -> Dict[str, Any]:

Generate OpenAI chat format with tool call response
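
Note the precondition: the run's output must itself be valid JSON, because it is parsed and re-serialized into the tool call's arguments; anything else raises ValueError. A small sketch with an invented structured output:

import json

# Hypothetical structured output for a task run.
structured_output = json.dumps({"sentiment": "positive", "confidence": 0.92})

# The assistant message in the generated record carries it as tool-call arguments:
assistant_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_1",
            "type": "function",
            "function": {"name": "task_response", "arguments": structured_output},
        }
    ],
}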

FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
    DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL: generate_chat_message_response,
    DatasetFormat.CHAT_MESSAGE_TOOLCALL_JSONL: generate_chat_message_toolcall,
}
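
This mapping is how dump_to_file dispatches on the requested format. A hypothetical manual use, assuming task_run is an existing TaskRun:

from kiln_ai.adapters.fine_tune.dataset_formatter import FORMAT_GENERATORS, DatasetFormat

generator = FORMAT_GENERATORS[DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL]
record = generator(task_run, "You are a helpful assistant.")
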
class DatasetFormatter:

Handles formatting of datasets into various output formats

DatasetFormatter(dataset: kiln_ai.datamodel.DatasetSplit, system_message: str)

Instance attributes: dataset, system_message, task

def dump_to_file(self, split_name: str, format_type: DatasetFormat, path: pathlib.Path | None = None) -> pathlib.Path:

Format the dataset into the specified format.

Args:
    split_name: Name of the split to dump
    format_type: Format to generate the dataset in
    path: Optional path to write to. If None, writes to temp directory

Returns:
    Path to the generated file
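
Putting it together, a minimal usage sketch, assuming an existing DatasetSplit (dataset_split below) whose parent task holds the runs referenced by its "train" split; the names and system message are illustrative:

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    DatasetFormat,
    DatasetFormatter,
)

# dataset_split is assumed to be a kiln_ai.datamodel.DatasetSplit loaded elsewhere.
formatter = DatasetFormatter(
    dataset=dataset_split,
    system_message="You are a helpful assistant.",
)

# With path=None the JSONL file is written to the system temp directory and
# its Path is returned; each line is one JSON chat example.
jsonl_path = formatter.dump_to_file(
    split_name="train",
    format_type=DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL,
)
print(jsonl_path)

Unsupported formats, unknown split names, and runs referenced by the split but missing from the task all raise ValueError rather than being silently skipped.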