Source code for sconce.rate_controllers.step_rate_controller

from sconce.rate_controllers.base import RateController

import numpy as np


class StepRateController(RateController):
    """
    A Learning Rate that falls in <num_drops> drops from <max_learning_rate>
    to <min_learning_rate> over the course of a session of <num_steps> steps
    (the session length is given to :meth:`start_session`). The Learning Rate
    is constant between drops.

    The learning rates used are ``num_drops + 1`` values linearly spaced from
    ``max_learning_rate`` to ``min_learning_rate`` (one per "region"); the
    session's steps are assigned to those regions in order.
    """

    def __init__(self, max_learning_rate, min_learning_rate, num_drops=1):
        """
        Arguments:
            max_learning_rate (float): learning rate used in the first region.
            min_learning_rate (float): learning rate used in the last region.
            num_drops (int, optional): number of times the learning rate
                drops; the schedule has ``num_drops + 1`` constant regions.
        """
        self.min_learning_rate = min_learning_rate
        self.max_learning_rate = max_learning_rate
        self.num_regions = num_drops + 1
        # One learning-rate value per region, linearly spaced from max to min.
        self.region_values = np.linspace(max_learning_rate, min_learning_rate,
                                         self.num_regions)
        # Unset until start_session(); new_learning_rate() checks for this.
        self.num_steps = None

    def start_session(self, num_steps):
        """
        Build the step -> region mapping for a session of ``num_steps`` steps.

        The step with 0-based index ``i`` is assigned to region 0 when
        ``i == 0`` and to region ``ceil(i * num_regions / (num_steps - 1)) - 1``
        otherwise. The first step therefore uses ``max_learning_rate``, and
        (when ``num_steps > 1``) the last step uses ``min_learning_rate``.

        Arguments:
            num_steps (int): number of steps in the session.

        Raises:
            ValueError: if ``num_steps`` is negative.
        """
        if num_steps < 0:
            raise ValueError(f"num_steps={num_steps} must be non-negative")
        self.num_steps = num_steps
        # Exact integer form of the previous float formulation
        # (floor of np.linspace(0, num_regions - 1e-12, num_steps)).
        # The float version yields the same indices for modest num_regions,
        # but once num_regions is large enough for the 1e-12 offset to be
        # lost to rounding it produces index num_regions -> IndexError.
        denominator = num_steps - 1
        # `i == 0` is tested first, so the division is never reached when
        # num_steps == 1 (denominator == 0).
        self.idxs = [0 if i == 0 else (i * self.num_regions - 1) // denominator
                     for i in range(num_steps)]

    def new_learning_rate(self, step, data):
        """
        Return the learning rate for ``step``.

        Arguments:
            step (int): 1-based index of the current step; must satisfy
                ``1 <= step <= num_steps``.
            data: not used by this controller (presumably part of the
                RateController interface -- confirm against the base class).

        Returns:
            The region value (an element of ``self.region_values``) assigned
            to ``step`` by :meth:`start_session`.

        Raises:
            RuntimeError: if :meth:`start_session` has not been called, or if
                ``step`` is outside ``[1, num_steps]``.
        """
        if self.num_steps is None:
            raise RuntimeError("You must call 'start_session' before calling "
                               "'new_learning_rate'")
        if step > self.num_steps:
            raise RuntimeError(f"Argument step={step}, should not "
                               f"exceed num_steps={self.num_steps}")
        if step < 1:
            # Without this check, step - 1 would be a negative list index and
            # silently return a rate from the end of the schedule.
            raise RuntimeError(f"Argument step={step}, should be at least 1")
        return self.region_values[self.idxs[step - 1]]