kiln_ai.adapters.fine_tune.dataset_formatter
import json
import tempfile
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Protocol
from uuid import uuid4

from kiln_ai.datamodel import DatasetSplit, TaskRun


class DatasetFormat(str, Enum):
    """Formats for dataset generation. Both the file format (like JSONL) and the internal structure (like chat/toolcall)."""

    """OpenAI chat format with plaintext response"""
    OPENAI_CHAT_JSONL = "openai_chat_jsonl"

    """OpenAI chat format with json response_format"""
    OPENAI_CHAT_JSON_SCHEMA_JSONL = "openai_chat_json_schema_jsonl"

    """OpenAI chat format with tool call response"""
    OPENAI_CHAT_TOOLCALL_JSONL = "openai_chat_toolcall_jsonl"

    """HuggingFace chat template in JSONL"""
    HUGGINGFACE_CHAT_TEMPLATE_JSONL = "huggingface_chat_template_jsonl"

    """HuggingFace chat template with tool calls in JSONL"""
    HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL = (
        "huggingface_chat_template_toolcall_jsonl"
    )

    """Vertex Gemini 1.5 format (flash and pro)"""
    VERTEX_GEMINI_1_5 = "vertex_gemini_1_5"


class FormatGenerator(Protocol):
    """Protocol for format generators"""

    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...


def best_task_output(task_run: TaskRun) -> str:
    """Get the best task output from the task run, preferring repaired output if available"""
    if task_run.repaired_output is not None:
        return task_run.repaired_output.output
    return task_run.output.output


def generate_chat_message_response(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with plaintext response"""
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": best_task_output(task_run)},
        ]
    }


def generate_json_schema_message(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with a JSON response (for use with response_format)"""
    # Load and dump to ensure it's valid JSON and goes to 1 line
    try:
        json_data = json.loads(best_task_output(task_run))
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in task output: {e}") from e
    json_string = json.dumps(json_data)

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": json_string},
        ]
    }


def generate_chat_message_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with tool call response"""
    try:
        arguments = json.loads(best_task_output(task_run))
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON for tool call: {e}") from e

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            # Yes, we parse then dump again. This ensures it's
                            # valid JSON, and ensures it goes to 1 line.
                            "arguments": json.dumps(arguments, ensure_ascii=False),
                        },
                    }
                ],
            },
        ]
    }


def generate_huggingface_chat_template(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate HuggingFace chat template"""
    return {
        "conversations": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": best_task_output(task_run)},
        ]
    }


def generate_huggingface_chat_template_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate HuggingFace chat template with tool calls"""
    try:
        arguments = json.loads(best_task_output(task_run))
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON for tool call: {e}") from e

    # See https://huggingface.co/docs/transformers/en/chat_templating
    return {
        "conversations": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "id": str(uuid4()).replace("-", "")[:9],
                            "arguments": arguments,
                        },
                    }
                ],
            },
        ]
    }


def generate_vertex_gemini_1_5(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate Vertex Gemini 1.5 format (flash and pro)"""
    # See https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare
    return {
        "systemInstruction": {
            "role": "system",
            "parts": [
                {
                    "text": system_message,
                }
            ],
        },
        "contents": [
            {
                "role": "user",
                "parts": [
                    {
                        "text": task_run.input,
                    }
                ],
            },
            {
                "role": "model",
                "parts": [
                    {
                        "text": best_task_output(task_run),
                    }
                ],
            },
        ],
    }


FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
    DatasetFormat.OPENAI_CHAT_JSONL: generate_chat_message_response,
    DatasetFormat.OPENAI_CHAT_JSON_SCHEMA_JSONL: generate_json_schema_message,
    DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
    DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
    DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
    DatasetFormat.VERTEX_GEMINI_1_5: generate_vertex_gemini_1_5,
}


class DatasetFormatter:
    """Handles formatting of datasets into various output formats"""

    def __init__(self, dataset: DatasetSplit, system_message: str):
        self.dataset = dataset
        self.system_message = system_message

        task = dataset.parent_task()
        if task is None:
            raise ValueError("Dataset has no parent task")
        self.task = task

    def dump_to_file(
        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
    ) -> Path:
        """
        Format the dataset into the specified format.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to temp directory

        Returns:
            Path to the generated file

        Note:
            The output is written in UTF-8 encoding with ensure_ascii=False to properly
            support international text content while maintaining readability.
        """
        if format_type not in FORMAT_GENERATORS:
            raise ValueError(f"Unsupported format: {format_type}")
        if split_name not in self.dataset.split_contents:
            raise ValueError(f"Split {split_name} not found in dataset")

        generator = FORMAT_GENERATORS[format_type]

        # Write to a temp file if no path is provided
        output_path = (
            path
            or Path(tempfile.gettempdir())
            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
        )

        runs = self.task.runs()
        runs_by_id = {run.id: run for run in runs}

        # Generate formatted output with UTF-8 encoding
        with open(output_path, "w", encoding="utf-8") as f:
            for run_id in self.dataset.split_contents[split_name]:
                # Use .get() so a missing run surfaces as the clear error below,
                # rather than a bare KeyError
                task_run = runs_by_id.get(run_id)
                if task_run is None:
                    raise ValueError(
                        f"Task run {run_id} not found. This is required by this dataset."
                    )

                example = generator(task_run, self.system_message)
                # Allow non-ASCII characters in the dataset.
                # Better readability for non-English users. If you don't support UTF-8... you should.
                f.write(json.dumps(example, ensure_ascii=False) + "\n")

        return output_path
class DatasetFormat(str, Enum)

    Formats for dataset generation. Both the file format (like JSONL) and the internal structure (like chat/toolcall).

    OPENAI_CHAT_JSONL = "openai_chat_jsonl"
        OpenAI chat format with plaintext response
    OPENAI_CHAT_JSON_SCHEMA_JSONL = "openai_chat_json_schema_jsonl"
        OpenAI chat format with json response_format
    OPENAI_CHAT_TOOLCALL_JSONL = "openai_chat_toolcall_jsonl"
        OpenAI chat format with tool call response
    HUGGINGFACE_CHAT_TEMPLATE_JSONL = "huggingface_chat_template_jsonl"
        HuggingFace chat template in JSONL
    HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL = "huggingface_chat_template_toolcall_jsonl"
        HuggingFace chat template with tool calls in JSONL
    VERTEX_GEMINI_1_5 = "vertex_gemini_1_5"
        Vertex Gemini 1.5 format (flash and pro)
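Because DatasetFormat mixes in str, members behave as plain strings, which keeps them easy to serialize and compare. A minimal sketch (the import path follows the module name above):

    from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat

    fmt = DatasetFormat.OPENAI_CHAT_JSONL
    assert fmt == "openai_chat_jsonl"                  # str-enum members compare equal to their values
    assert DatasetFormat("openai_chat_jsonl") is fmt   # and can be looked up by value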
class FormatGenerator(Protocol)

    Protocol for format generators.

    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...
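Any callable matching this signature satisfies the protocol, so custom formats can be added without subclassing. A minimal sketch of a hypothetical generator (generate_plain_pairs is not part of the module):

    from typing import Any, Dict

    from kiln_ai.adapters.fine_tune.dataset_formatter import (
        FormatGenerator,
        best_task_output,
    )
    from kiln_ai.datamodel import TaskRun


    def generate_plain_pairs(task_run: TaskRun, system_message: str) -> Dict[str, Any]:
        # Hypothetical generator: emits bare prompt/completion pairs.
        return {
            "prompt": f"{system_message}\n\n{task_run.input}",
            "completion": best_task_output(task_run),
        }


    generator: FormatGenerator = generate_plain_pairs  # type-checks via structural typing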
def best_task_output(task_run: TaskRun) -> str

    Get the best task output from the task run, preferring the repaired output if available.
def generate_chat_message_response(task_run: TaskRun, system_message: str) -> Dict[str, Any]

    Generate OpenAI chat format with plaintext response.
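For illustration, a single generated record (one line of the JSONL file after serialization) looks like the following; the content values are hypothetical:

    {
        "messages": [
            {"role": "system", "content": "You are a helpful summarizer."},
            {"role": "user", "content": "Summarize: ..."},
            {"role": "assistant", "content": "A short summary."},
        ]
    }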
def generate_json_schema_message(task_run: TaskRun, system_message: str) -> Dict[str, Any]

    Generate OpenAI chat format with a JSON response (for use with response_format). The task output is parsed and re-serialized so the assistant content is guaranteed to be valid, single-line JSON; invalid JSON raises ValueError.
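For a hypothetical pretty-printed task output containing {"rating": 5}, the generator collapses it to one line:

    {
        "messages": [
            {"role": "system", "content": "..."},
            {"role": "user", "content": "..."},
            {"role": "assistant", "content": '{"rating": 5}'},  # compact JSON string
        ]
    }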
def generate_chat_message_toolcall(task_run: TaskRun, system_message: str) -> Dict[str, Any]

    Generate OpenAI chat format with tool call response.
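Here the structured output is delivered as a tool call rather than as message content. A hypothetical record:

    {
        "messages": [
            {"role": "system", "content": "..."},
            {"role": "user", "content": "..."},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_1",  # the generator always uses this fixed id
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "arguments": '{"rating": 5}',  # serialized JSON string, not a dict
                        },
                    }
                ],
            },
        ]
    }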
def generate_huggingface_chat_template(task_run: TaskRun, system_message: str) -> Dict[str, Any]

    Generate HuggingFace chat template.
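The structure mirrors the OpenAI chat format; only the top-level key differs ("conversations" instead of "messages"):

    {
        "conversations": [
            {"role": "system", "content": "..."},
            {"role": "user", "content": "..."},
            {"role": "assistant", "content": "..."},
        ]
    }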
def generate_huggingface_chat_template_toolcall(task_run: TaskRun, system_message: str) -> Dict[str, Any]

    Generate HuggingFace chat template with tool calls. See https://huggingface.co/docs/transformers/en/chat_templating
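Two differences from the OpenAI tool call format: the call id is a random 9-character string nested inside "function", and the arguments stay a dict rather than being re-serialized to a string. A hypothetical record:

    {
        "conversations": [
            {"role": "system", "content": "..."},
            {"role": "user", "content": "..."},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "id": "a1b2c3d4e",           # 9 chars taken from a uuid4
                            "arguments": {"rating": 5},  # dict, not a JSON string
                        },
                    }
                ],
            },
        ]
    }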
def generate_vertex_gemini_1_5(task_run: TaskRun, system_message: str) -> Dict[str, Any]

    Generate Vertex Gemini 1.5 format (flash and pro). See https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare
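The Gemini tuning format separates the system instruction from the user/model turns. A hypothetical record:

    {
        "systemInstruction": {"role": "system", "parts": [{"text": "..."}]},
        "contents": [
            {"role": "user", "parts": [{"text": "..."}]},
            {"role": "model", "parts": [{"text": "..."}]},
        ],
    }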
class DatasetFormatter

    Handles formatting of datasets into various output formats.

    DatasetFormatter(dataset: DatasetSplit, system_message: str)
        Raises ValueError if the dataset has no parent task.

    def dump_to_file(self, split_name: str, format_type: DatasetFormat, path: Path | None = None) -> Path

        Format the dataset into the specified format.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to the temp directory

        Returns:
            Path to the generated file

        Note:
            The output is written in UTF-8 encoding with ensure_ascii=False to properly
            support international text content while maintaining readability.
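A minimal usage sketch. How the DatasetSplit is obtained is application-specific; the dataset variable below is assumed to be a loaded split with a parent task:

    from kiln_ai.adapters.fine_tune.dataset_formatter import (
        DatasetFormat,
        DatasetFormatter,
    )

    # `dataset` is an existing DatasetSplit; loading it is app-specific.
    formatter = DatasetFormatter(dataset, system_message="You are a helpful assistant.")
    output_path = formatter.dump_to_file("train", DatasetFormat.OPENAI_CHAT_JSONL)
    # With path=None the file lands in the temp directory, named
    # <dataset name>_train_<format>.jsonl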