power_cogs.tune package¶
Submodules¶
power_cogs.tune.tune_wrapper module¶
class power_cogs.tune.tune_wrapper.CustomMLflowLoggerCallback(*args, **kwargs)[source]¶
Bases: ray.tune.integration.mlflow.MLflowLoggerCallback
on_trial_save(iteration: int, trials: List[ray.tune.trial.Trial], trial: ray.tune.trial.Trial, **info)[source]¶
Called after receiving a checkpoint from a trial.

Arguments:
    iteration (int): Number of iterations of the tuning loop.
    trials (List[Trial]): List of trials.
    trial (Trial): Trial that just saved a checkpoint.
    **info: Kwargs dict for forward compatibility.
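The callback plugs into a standard Ray Tune run through the callbacks argument of tune.run. A minimal sketch, assuming CustomMLflowLoggerCallback forwards its keyword arguments (such as experiment_name) to the parent MLflowLoggerCallback; the config key "lr" is illustrative only:

    from ray import tune

    from power_cogs.tune.tune_wrapper import CustomMLflowLoggerCallback, TuneTrainer

    analysis = tune.run(
        TuneTrainer,
        config={"lr": tune.grid_search([1e-3, 1e-2])},  # hypothetical config key
        stop={"training_iteration": 10},
        checkpoint_freq=1,  # each saved checkpoint triggers on_trial_save
        callbacks=[CustomMLflowLoggerCallback(experiment_name="my_experiment")],
    )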
class power_cogs.tune.tune_wrapper.TuneTrainer(config=None, logger_creator=None)[source]¶
Bases: ray.tune.trainable.Trainable
load_checkpoint(tmp_checkpoint_dir)[source]¶
Subclasses should override this to implement restore().

Warning:
    In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in Trainable.save_checkpoint may be changed.

    If Trainable.save_checkpoint returned a prefixed string, the prefix of the checkpoint string returned by Trainable.save_checkpoint may be changed. This is because trial pausing depends on temporary directories.

    The directory structure under the checkpoint_dir provided to Trainable.save_checkpoint is preserved.

    See the example below.

    class Example(Trainable):
        def save_checkpoint(self, checkpoint_path):
            print(checkpoint_path)
            return os.path.join(checkpoint_path, "my/check/point")

        def load_checkpoint(self, checkpoint):
            print(checkpoint)

    >>> trainer = Example()
    >>> obj = trainer.save_to_object()  # This is used when PAUSED.
    <logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point
    >>> trainer.restore_from_object(obj)  # Note the different prefix.
    <logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point

New in version 0.8.7.

Args:
    checkpoint (str|dict): If dict, the return value is as returned by save_checkpoint. If a string, then it is a checkpoint path that may have a different prefix than that returned by save_checkpoint. The directory structure underneath the checkpoint_dir from save_checkpoint is preserved.
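As a concrete illustration of the dict form of this contract, a minimal sketch (not part of power_cogs) in which save_checkpoint returns a plain dict and load_checkpoint later receives that same dict back from Tune:

    from ray import tune

    class DictCheckpointTrainable(tune.Trainable):
        def setup(self, config):
            self.count = 0

        def step(self):
            self.count += 1
            return {"score": self.count}

        def save_checkpoint(self, tmp_checkpoint_dir):
            # Returning a dict: Tune serializes it and hands it back on restore.
            return {"count": self.count}

        def load_checkpoint(self, checkpoint):
            # checkpoint is exactly the dict returned by save_checkpoint.
            self.count = checkpoint["count"]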
save_checkpoint(tmp_checkpoint_dir)[source]¶
Subclasses should override this to implement save().

Warning:
    Do not rely on absolute paths in the implementation of Trainable.save_checkpoint and Trainable.load_checkpoint.

    Use validate_save_restore to catch Trainable.save_checkpoint / Trainable.load_checkpoint errors before execution.

    >>> from ray.tune.utils import validate_save_restore
    >>> validate_save_restore(MyTrainableClass)
    >>> validate_save_restore(MyTrainableClass, use_object_store=True)

New in version 0.8.7.

Args:
    tmp_checkpoint_dir (str): The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved.

Returns:
    A dict or string. If string, the return value is expected to be prefixed by tmp_checkpoint_dir. If dict, the return value will be automatically serialized by Tune and passed to Trainable.load_checkpoint().

Examples:
    >>> print(trainable1.save_checkpoint("/tmp/checkpoint_1"))
    "/tmp/checkpoint_1/my_checkpoint_file"
    >>> print(trainable2.save_checkpoint("/tmp/checkpoint_2"))
    {"some": "data"}

    >>> trainable.save_checkpoint("/tmp/bad_example")
    "/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
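For the string form, a sketch along the same lines (the torch model and file name are illustrative assumptions): write the checkpoint file under tmp_checkpoint_dir and return a path prefixed by it; the matching load_checkpoint then reads from whatever prefix Tune hands back.

    import os

    import torch
    from ray import tune

    class FileCheckpointTrainable(tune.Trainable):
        def setup(self, config):
            self.model = torch.nn.Linear(4, 1)

        def step(self):
            return {"loss": float(torch.rand(1))}

        def save_checkpoint(self, tmp_checkpoint_dir):
            path = os.path.join(tmp_checkpoint_dir, "model.pt")
            torch.save(self.model.state_dict(), path)
            return path  # prefixed by tmp_checkpoint_dir, as required

        def load_checkpoint(self, checkpoint):
            # The prefix of `checkpoint` may differ from the one returned by
            # save_checkpoint, but the relative structure is preserved.
            self.model.load_state_dict(torch.load(checkpoint))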
setup(config)[source]¶
Subclasses should override this for custom initialization.

New in version 0.8.7.

Args:
    config (dict): Hyperparameters and other configs given. Copy of self.config.
step()[source]¶
Subclasses should override this to implement train().

The return value will be automatically passed to the loggers. Users can also return tune.result.DONE or tune.result.SHOULD_CHECKPOINT as a key to manually trigger termination or checkpointing of this trial. Note that manual checkpointing only works when subclassing Trainables.

New in version 0.8.7.

Returns:
    A dict that describes training progress.
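Putting setup and step together, a minimal sketch of a Trainable subclass (the config keys "lr" and "max_steps" are hypothetical) that reports a metric each iteration and ends the trial manually by returning the tune.result.DONE key:

    from ray import tune
    from ray.tune import result

    class CountingTrainable(tune.Trainable):
        def setup(self, config):
            # config is a copy of self.config with the sampled hyperparameters.
            self.lr = config["lr"]
            self.max_steps = config.get("max_steps", 5)
            self.steps_done = 0

        def step(self):
            self.steps_done += 1
            metrics = {"mean_loss": 1.0 / (self.steps_done * self.lr)}
            if self.steps_done >= self.max_steps:
                metrics[result.DONE] = True  # manual termination
            return metrics

    analysis = tune.run(
        CountingTrainable,
        config={"lr": tune.grid_search([0.1, 1.0]), "max_steps": 5},
    )
    print(analysis.get_best_config(metric="mean_loss", mode="min"))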