Coverage for src/logfire_callback/callback.py: 97%

35 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-23 13:43 +0100

import importlib.util
import json
import os
from typing import Any

import logfire
from transformers import (
    AutoModelForMaskedLM,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

13 

14 

def is_logfire_available() -> bool:
    """Report whether the ``logfire`` package can be imported.

    Returns:
        bool: True if a module spec for ``logfire`` is found, False otherwise.
    """
    spec = importlib.util.find_spec("logfire")
    return spec is not None

22 

23 

class LogfireCallback(TrainerCallback):
    """A callback for logging training events to Logfire.

    This callback integrates with the Logfire logging service to track
    training progress, metrics, and events during model training. It
    inherits from HuggingFace's TrainerCallback and logs important events
    like training start, end, and periodic metric updates.

    Note:
        Requires the `logfire` package to be installed. Install with
        `pip install logfire`. A Logfire API token must be set in the
        LOGFIRE_TOKEN environment variable.

    Attributes:
        _logfire: The logfire module instance for making logging calls.
        _logfire_token: The API token for Logfire authentication.
        _initialized: Boolean flag indicating if the callback has been
            initialized.
    """

    def __init__(self) -> None:
        """Initialize the LogfireCallback.

        Raises:
            RuntimeError: If the logfire package is not installed.
        """
        if not is_logfire_available():
            raise RuntimeError(
                "LogfireCallback requires `logfire` to be installed. Run `pip install logfire`."
            )

        self._logfire = logfire
        # Token may legitimately be absent; `configure` then receives None
        # and logfire falls back to its own credential discovery.
        self._logfire_token = os.getenv("LOGFIRE_TOKEN", None)
        self._initialized = False

        self._logfire.configure(
            token=self._logfire_token, console=False, inspect_arguments=False
        )

    @staticmethod
    def _make_serializable(obj: object) -> object:
        """Recursively convert an object to a JSON-serializable form.

        Objects with a ``__dict__`` are flattened to a dict of their public,
        non-callable attributes; containers are converted element-wise;
        JSON-native scalars pass through; anything else becomes ``str(obj)``.

        Args:
            obj (object): The object to make serializable.

        Returns:
            object: A JSON-serializable version of the input object.
        """
        if hasattr(obj, "__dict__"):
            return {
                k: LogfireCallback._make_serializable(v)
                for k, v in obj.__dict__.items()
                if not k.startswith("_") and not callable(v)
            }
        if isinstance(obj, list | tuple):
            return [LogfireCallback._make_serializable(x) for x in obj]
        if isinstance(obj, dict):
            return {k: LogfireCallback._make_serializable(v) for k, v in obj.items()}
        if isinstance(obj, int | float | str | bool | type(None)):
            return obj
        return str(obj)

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: AutoModelForMaskedLM | None = None,
        **kwargs: dict[str, Any],
    ) -> None:
        """Called when training begins.

        Logs the initial training parameters and configuration to Logfire.

        Args:
            args (TrainingArguments): Training arguments and hyperparameters.
            state (TrainerState): Current state of the trainer.
            control (TrainerControl): Training control object.
            model (AutoModelForMaskedLM | None, optional): The model being
                trained. Defaults to None.
            **kwargs (dict[str, Any]): Additional keyword arguments.

        Note:
            Only logs when running on the main process
            (is_local_process_zero) to avoid duplicate logs in distributed
            training.
        """
        if self._logfire and state.is_local_process_zero:
            args_dict = self._make_serializable(args)

            self._logfire.info(
                "Training started with the following parameters: {args}",
                args=json.dumps(args_dict),
            )

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: dict[str, Any],
    ) -> None:
        """Called when training ends.

        Logs the completion of training to Logfire.

        Args:
            args (TrainingArguments): Training arguments and hyperparameters.
            state (TrainerState): Final state of the trainer.
            control (TrainerControl): Training control object.
            **kwargs (dict[str, Any]): Additional keyword arguments.

        Note:
            Only logs when running on the main process
            (is_local_process_zero) to avoid duplicate logs in distributed
            training.
        """
        if self._logfire and state.is_local_process_zero:
            self._logfire.info(
                "Training successfully completed.",
            )

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: dict[str, Any] | None = None,
        **kwargs: dict[str, Any],
    ) -> None:
        """Called when training metrics are logged.

        Sends training metrics and logs to Logfire for monitoring and
        visualization.

        Args:
            args (TrainingArguments): Training arguments and hyperparameters.
            state (TrainerState): Current state of the trainer.
            control (TrainerControl): Training control object.
            logs (dict[str, Any] | None, optional): Dictionary of logs to
                record. Defaults to None.
            **kwargs (dict[str, Any]): Additional keyword arguments.

        Note:
            Only logs when running on the main process
            (is_local_process_zero) to avoid duplicate logs in distributed
            training.

            The logs dictionary typically contains metrics such as:
            - loss: Training loss
            - learning_rate: Current learning rate
            - epoch: Current epoch number
            - step: Current training step
            - eval_loss: Evaluation loss (during evaluation)
            - eval_runtime: Evaluation runtime
            - eval_samples_per_second: Evaluation throughput
            - eval_steps_per_second: Evaluation steps per second
        """
        if self._logfire and state.is_local_process_zero:
            self._logfire.info(
                "{logs}",
                logs=logs,
            )
176 )