kiln_ai.adapters.fine_tune.dataset_formatter

import json
import tempfile
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Protocol
from uuid import uuid4

from kiln_ai.datamodel import DatasetSplit, TaskRun


class DatasetFormat(str, Enum):
    """Formats for dataset generation, covering both file format (like JSONL) and internal structure (like chat/toolcall)"""

    """OpenAI chat format with plaintext response"""
    OPENAI_CHAT_JSONL = "openai_chat_jsonl"

    """OpenAI chat format with json response_format"""
    OPENAI_CHAT_JSON_SCHEMA_JSONL = "openai_chat_json_schema_jsonl"

    """OpenAI chat format with tool call response"""
    OPENAI_CHAT_TOOLCALL_JSONL = "openai_chat_toolcall_jsonl"

    """HuggingFace chat template in JSONL"""
    HUGGINGFACE_CHAT_TEMPLATE_JSONL = "huggingface_chat_template_jsonl"

    """HuggingFace chat template with tool calls in JSONL"""
    HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL = (
        "huggingface_chat_template_toolcall_jsonl"
    )

    """Vertex Gemini 1.5 format (flash and pro)"""
    VERTEX_GEMINI_1_5 = "vertex_gemini_1_5"
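
# All members are str-valued, so a format round-trips through its wire value.
# Illustrative sketch, not part of the module:
#
#   DatasetFormat("openai_chat_jsonl") is DatasetFormat.OPENAI_CHAT_JSONL  # True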


class FormatGenerator(Protocol):
    """Protocol for format generators"""

    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...


def best_task_output(task_run: TaskRun) -> str:
    """Get the best task output from the task run, preferring repaired output if available"""
    if task_run.repaired_output is not None:
        return task_run.repaired_output.output
    return task_run.output.output


def generate_chat_message_response(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with plaintext response"""
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": best_task_output(task_run)},
        ]
    }
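
# Example (illustrative, not part of the module): for a hypothetical TaskRun
# whose input is "2+2" and whose best output is "4", this generator returns:
#
#   {"messages": [
#       {"role": "system", "content": "You are a calculator."},
#       {"role": "user", "content": "2+2"},
#       {"role": "assistant", "content": "4"},
#   ]}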


def generate_json_schema_message(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with a JSON response (for json_schema response_format)"""
    # Load and dump to ensure it's valid JSON and goes to 1 line
    try:
        json_data = json.loads(best_task_output(task_run))
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in task output: {e}") from e
    json_string = json.dumps(json_data)

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": json_string},
        ]
    }
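
# The load/dump round trip normalizes the stored output to one compact line,
# as JSONL requires. Illustrative values:
#
#   json.dumps(json.loads('{\n  "answer": 4\n}'))  ->  '{"answer": 4}'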


def generate_chat_message_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with tool call response"""
    try:
        arguments = json.loads(best_task_output(task_run))
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in output for tool call: {e}") from e

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            # Yes we parse then dump again. This ensures it's valid JSON, and ensures it goes to 1 line
                            "arguments": json.dumps(arguments, ensure_ascii=False),
                        },
                    }
                ],
            },
        ]
    }
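
# Following the OpenAI convention, "arguments" is a serialized JSON string,
# not a dict. For a hypothetical output '{"answer": 4}' the assistant message is:
#
#   {"role": "assistant", "content": None, "tool_calls": [
#       {"id": "call_1", "type": "function",
#        "function": {"name": "task_response", "arguments": '{"answer": 4}'}}]}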


def generate_huggingface_chat_template(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate HuggingFace chat template"""
    return {
        "conversations": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": best_task_output(task_run)},
        ]
    }


def generate_huggingface_chat_template_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate HuggingFace chat template with tool calls"""
    try:
        arguments = json.loads(best_task_output(task_run))
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in output for tool call: {e}") from e

    # See https://huggingface.co/docs/transformers/en/chat_templating
    return {
        "conversations": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "id": str(uuid4()).replace("-", "")[:9],
                            "arguments": arguments,
                        },
                    }
                ],
            },
        ]
    }
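
# Unlike the OpenAI variant above, "arguments" stays a parsed dict per the
# HuggingFace chat-template convention, and the call id is a random
# 9-character string from uuid4. Illustrative function entry:
#
#   {"name": "task_response", "id": "1a2b3c4d5", "arguments": {"answer": 4}}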


def generate_vertex_gemini_1_5(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate Vertex Gemini 1.5 format (flash and pro)"""
    # See https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare
    return {
        "systemInstruction": {
            "role": "system",
            "parts": [
                {
                    "text": system_message,
                }
            ],
        },
        "contents": [
            {
                "role": "user",
                "parts": [
                    {
                        "text": task_run.input,
                    }
                ],
            },
            {
                "role": "model",
                "parts": [
                    {
                        "text": best_task_output(task_run),
                    }
                ],
            },
        ],
    }
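
# One resulting JSONL record, with hypothetical values:
#
#   {"systemInstruction": {"role": "system", "parts": [{"text": "You are a calculator."}]},
#    "contents": [{"role": "user", "parts": [{"text": "2+2"}]},
#                 {"role": "model", "parts": [{"text": "4"}]}]}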


FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
    DatasetFormat.OPENAI_CHAT_JSONL: generate_chat_message_response,
    DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL: generate_json_schema_message,
    DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
    DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
    DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
    DatasetFormat.VERTEX_GEMINI_1_5: generate_vertex_gemini_1_5,
}
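
# Dispatch sketch: each format maps to a plain callable satisfying the
# FormatGenerator protocol, so one dict lookup covers every supported format
# (task_run and system_message are placeholders here):
#
#   generator = FORMAT_GENERATORS[DatasetFormat.OPENAI_CHAT_JSONL]
#   example = generator(task_run, system_message)  # -> Dict[str, Any]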


class DatasetFormatter:
    """Handles formatting of datasets into various output formats"""

    def __init__(self, dataset: DatasetSplit, system_message: str):
        self.dataset = dataset
        self.system_message = system_message

        task = dataset.parent_task()
        if task is None:
            raise ValueError("Dataset has no parent task")
        self.task = task

    def dump_to_file(
        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
    ) -> Path:
        """
        Format the dataset into the specified format.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to the temp directory

        Returns:
            Path to the generated file

        Note:
            The output is written in UTF-8 encoding with ensure_ascii=False to properly
            support international text content while maintaining readability.
        """
        if format_type not in FORMAT_GENERATORS:
            raise ValueError(f"Unsupported format: {format_type}")
        if split_name not in self.dataset.split_contents:
            raise ValueError(f"Split {split_name} not found in dataset")

        generator = FORMAT_GENERATORS[format_type]

        # Write to a temp file if no path is provided
        output_path = (
            path
            or Path(tempfile.gettempdir())
            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
        )

        runs = self.task.runs()
        runs_by_id = {run.id: run for run in runs}

        # Generate formatted output with UTF-8 encoding
        with open(output_path, "w", encoding="utf-8") as f:
            for run_id in self.dataset.split_contents[split_name]:
                # Use .get() so a missing run surfaces as a clear error below;
                # plain indexing would raise a KeyError and never return None
                task_run = runs_by_id.get(run_id)
                if task_run is None:
                    raise ValueError(
                        f"Task run {run_id} not found. This is required by this dataset."
                    )

                example = generator(task_run, self.system_message)
                # Allow non-ASCII characters in the dataset for better readability
                # for non-English users. If you don't support UTF-8... you should.
                f.write(json.dumps(example, ensure_ascii=False) + "\n")

        return output_path
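
# Usage sketch (the DatasetSplit instance, split name, and system message below
# are hypothetical):
#
#   formatter = DatasetFormatter(dataset=split, system_message="You are a calculator.")
#   path = formatter.dump_to_file("train", DatasetFormat.OPENAI_CHAT_JSONL)
#   # path points at a JSONL file, written to the system temp directory by default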