Coverage for src/logfire_callback/callback.py: 97%
35 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-23 13:43 +0100
import importlib.util
import json
import os
from typing import Any

import logfire
from transformers import (
    AutoModelForMaskedLM,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
def is_logfire_available() -> bool:
    """Report whether the ``logfire`` package can be imported.

    Returns:
        bool: True when an import spec for ``logfire`` is found,
        False otherwise.
    """
    spec = importlib.util.find_spec("logfire")
    return spec is not None
class LogfireCallback(TrainerCallback):
    """A callback for logging training events to Logfire.

    This callback integrates with the Logfire logging service to track
    training progress, metrics, and events during model training. It
    inherits from HuggingFace's TrainerCallback and logs important events
    like training start, end, and periodic metric updates.

    Note:
        Requires the `logfire` package to be installed. Install with
        `pip install logfire`. A Logfire API token must be set in the
        LOGFIRE_TOKEN environment variable.

    Attributes:
        _logfire: The logfire module instance for making logging calls.
        _logfire_token: The API token for Logfire authentication.
        _initialized: Boolean flag indicating if the callback has been initialized.
    """

    def __init__(self) -> None:
        """Initialize the LogfireCallback.

        Configures the logfire client immediately (console output disabled).

        Raises:
            RuntimeError: If the logfire package is not installed.
        """
        if not is_logfire_available():
            raise RuntimeError(
                "LogfireCallback requires `logfire` to be installed. Run `pip install logfire`."
            )

        self._logfire = logfire
        # May be None; logfire then falls back to its own credential discovery.
        self._logfire_token = os.getenv("LOGFIRE_TOKEN", None)
        self._initialized = False

        self._logfire.configure(
            token=self._logfire_token, console=False, inspect_arguments=False
        )

    @staticmethod
    def _make_serializable(obj: object) -> object:
        """Recursively convert *obj* into a JSON-serializable structure.

        Objects exposing ``__dict__`` are flattened to a dict of their
        public, non-callable attributes; lists/tuples and dicts are
        converted element-wise; JSON primitives pass through unchanged;
        anything else falls back to its ``str`` representation.

        Args:
            obj (object): The object to make serializable.

        Returns:
            object: A JSON-serializable version of the input object.
        """
        convert = LogfireCallback._make_serializable
        if hasattr(obj, "__dict__"):
            return {
                k: convert(v)
                for k, v in obj.__dict__.items()
                if not k.startswith("_") and not callable(v)
            }
        if isinstance(obj, list | tuple):
            return [convert(x) for x in obj]
        if isinstance(obj, dict):
            return {k: convert(v) for k, v in obj.items()}
        if isinstance(obj, int | float | str | bool | type(None)):
            return obj
        return str(obj)

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: AutoModelForMaskedLM | None = None,
        **kwargs: Any,
    ) -> None:
        """Called when training begins.

        Logs the initial training parameters and configuration to Logfire.

        Args:
            args (TrainingArguments): Training arguments and hyperparameters.
            state (TrainerState): Current state of the trainer.
            control (TrainerControl): Training control object.
            model (AutoModelForMaskedLM | None, optional): The model being
                trained. Defaults to None.
            **kwargs (Any): Additional keyword arguments.

        Note:
            Only logs when running on the main process
            (is_local_process_zero) to avoid duplicate logs in distributed
            training.
        """
        if self._logfire and state.is_local_process_zero:
            args_dict = self._make_serializable(args)
            self._logfire.info(
                "Training started with the following parameters: {args}",
                args=json.dumps(args_dict),
            )

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: Any,
    ) -> None:
        """Called when training ends.

        Logs the completion of training to Logfire.

        Args:
            args (TrainingArguments): Training arguments and hyperparameters.
            state (TrainerState): Final state of the trainer.
            control (TrainerControl): Training control object.
            **kwargs (Any): Additional keyword arguments.

        Note:
            Only logs when running on the main process
            (is_local_process_zero) to avoid duplicate logs in distributed
            training.
        """
        if self._logfire and state.is_local_process_zero:
            self._logfire.info(
                "Training successfully completed.",
            )

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Called when training metrics are logged.

        Sends training metrics and logs to Logfire for monitoring and
        visualization.

        Args:
            args (TrainingArguments): Training arguments and hyperparameters.
            state (TrainerState): Current state of the trainer.
            control (TrainerControl): Training control object.
            logs (dict[str, Any] | None, optional): Dictionary of logs to
                record. Defaults to None.
            **kwargs (Any): Additional keyword arguments.

        Note:
            Only logs when running on the main process
            (is_local_process_zero) to avoid duplicate logs in distributed
            training.

            The logs dictionary typically contains metrics such as:
            - loss: Training loss
            - learning_rate: Current learning rate
            - epoch: Current epoch number
            - step: Current training step
            - eval_loss: Evaluation loss (during evaluation)
            - eval_runtime: Evaluation runtime
            - eval_samples_per_second: Evaluation throughput
            - eval_steps_per_second: Evaluation steps per second
        """
        if self._logfire and state.is_local_process_zero:
            self._logfire.info(
                "{logs}",
                logs=logs,
            )