Source code for power_cogs.tune.tune_wrapper

import os
from typing import Any, List

import attr
from hydra.utils import instantiate
from omegaconf import OmegaConf
from ray import tune
from ray.tune.integration.mlflow import MLflowLoggerCallback
from ray.tune.trial import Trial


class CustomMLflowLoggerCallback(MLflowLoggerCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def on_trial_save(self, iteration: int, trials: List[Trial], trial: Trial, **info):
        """Called after receiving a checkpoint from a trial.

        Arguments:
            iteration (int): Number of iterations of the tuning loop.
            trials (List[Trial]): List of trials.
            trial (Trial): Trial that just saved a checkpoint.
            **info: Kwargs dict for forward compatibility.
        """
        run_id = self._trial_runs[trial]
        # Log the trial's logdir as MLflow artifacts on every checkpoint save,
        # unconditionally (the base class's save_artifact flag is not checked here).
        self.client.log_artifacts(run_id, local_dir=trial.logdir)


class TuneTrainer(tune.Trainable):
    def setup(self, config):
        overrides = config.get("overrides", {})
        trainer_config = OmegaConf.structured(config.get("config"))
        # Attach the full structured config to the trainer node before instantiation.
        trainer_config.trainer.config = trainer_config
        self.trainer = instantiate(
            trainer_config.trainer, _recursive_=False, **overrides
        )

    def step(self):
        out = self.trainer.train_iter(
            self.trainer.batch_size, self.trainer.current_iteration
        )
        self.trainer.current_iteration += 1
        return out["metrics"]

    def save_checkpoint(self, tmp_checkpoint_dir):
        return self.trainer.save(tmp_checkpoint_dir)

    def load_checkpoint(self, tmp_checkpoint_dir):
        self.trainer.load(
            os.path.join(tmp_checkpoint_dir, "{}.pt".format(self.trainer.name))
        )

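
# The trainer object driven above is duck-typed: TuneTrainer relies only on the
# attributes and methods it calls. A minimal sketch of that implied interface,
# inferred from this file (the concrete trainer classes in power_cogs may differ):
#
#     class MyTrainer:
#         name = "my_trainer"          # used to locate "<name>.pt" checkpoints
#         batch_size = 32
#         current_iteration = 0
#
#         def train_iter(self, batch_size, iteration):
#             ...  # one training step; returns a dict with a "metrics" key
#
#         def save(self, checkpoint_dir):
#             ...  # writes "<name>.pt" under checkpoint_dir
#
#         def load(self, checkpoint_path):
#             ...  # restores state from the given checkpoint file
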
def create_stopper(config):
    epochs = config.get("epochs", 1000)
    loss_threshold: float = -float("inf")
    if config.get("early_stoppage"):
        loss_threshold = config.get("loss_threshold", 0.005)

    def stopper(trial_id, result):
        if result["training_iteration"] >= epochs:
            return True
        return result["loss"] < loss_threshold

    return stopper

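
# Sketch of the stopper contract: tune.run's ``stop`` callable receives the
# trial id and the latest result dict. With hypothetical config values:
#
#     stop_fn = create_stopper({"epochs": 10, "early_stoppage": True})
#     stop_fn("trial_0", {"training_iteration": 10, "loss": 0.1})   # True (epoch cap)
#     stop_fn("trial_0", {"training_iteration": 3, "loss": 0.001})  # True (loss < 0.005)
#     stop_fn("trial_0", {"training_iteration": 3, "loss": 0.1})    # False
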

@attr.s
class TuneWrapper:
    config: Any = attr.ib()
    trainer_config: Any = attr.ib()
    trainer_overrides: Any = attr.ib()

    def tune(self):
        mlflow_callback = CustomMLflowLoggerCallback(
            experiment_name=self.trainer_config["trainer"]["name"],
            save_artifact=False,
        )
        callbacks = self.config.get("callbacks", [])
        callbacks.append(mlflow_callback)
        self.config["callbacks"] = callbacks
        if "stop" not in self.config:
            self.config["stop"] = create_stopper(self.trainer_config["trainer"])
        return tune.run(
            TuneTrainer,
            config={
                "config": self.trainer_config,
                "overrides": self.trainer_overrides,
            },
            **self.config
        )
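

# Usage sketch with hypothetical values -- in power_cogs these configs are
# normally assembled by Hydra, and the trainer entry needs a valid
# ``_target_`` for hydra.utils.instantiate to resolve:
#
#     wrapper = TuneWrapper(
#         config={"num_samples": 4},  # extra keys are forwarded to tune.run
#         trainer_config={
#             "trainer": {
#                 "_target_": "my_project.trainers.ExampleTrainer",  # hypothetical
#                 "name": "example",
#                 "epochs": 100,
#             }
#         },
#         trainer_overrides={},
#     )
#     analysis = wrapper.tune()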