Source code for sconce.rate_controllers.triangle_rate_controller

from sconce.rate_controllers.base import RateController

import numpy as np


class TriangleRateController(RateController):
    """
    A learning-rate schedule shaped like a triangle.

    The learning rate rises linearly from <min_learning_rate> to
    <max_learning_rate> over the first ceil(<num_steps>/2) steps, then drops
    linearly back to <min_learning_rate> over the remaining steps.

    :meth:`start_session` must be called (with the session's step count)
    before :meth:`new_learning_rate` is used.

    Arguments:
        min_learning_rate (float): the learning rate at the first and last
            step of the session.
        max_learning_rate (float): the peak learning rate of the schedule
            (it appears in the schedule when num_steps >= 3).
    """
    def __init__(self, min_learning_rate, max_learning_rate):
        self.min_learning_rate = min_learning_rate
        self.max_learning_rate = max_learning_rate
        # Precomputed per-step rates; None until start_session is called.
        self.learning_rates = None

    def start_session(self, num_steps):
        """
        Precompute the learning rate for every step of a session.

        Arguments:
            num_steps (int): the number of steps in the session. The rising
                segment gets ceil(num_steps / 2) steps and the falling
                segment gets the rest, for exactly num_steps rates in total.
        """
        # all fractions round up (instead of truncate): negating, floor
        # dividing, and negating back computes ceil(num_steps / 2) in ints.
        rise_steps = -(-num_steps // 2)
        rising_rates = np.linspace(self.min_learning_rate,
                                   self.max_learning_rate,
                                   rise_steps)
        # The falling segment is generated with one extra point so that it
        # starts exactly at max_learning_rate; that first point is dropped
        # below, leaving num_steps - rise_steps falling rates.
        falling_rates = np.linspace(self.max_learning_rate,
                                    self.min_learning_rate,
                                    (num_steps + 1) - rise_steps)
        self.learning_rates = np.concatenate((rising_rates, falling_rates[1:]))

    def new_learning_rate(self, step, data):
        """
        Return the precomputed learning rate for a step.

        Arguments:
            step (int): the 1-based index of the current step; must satisfy
                1 <= step <= num_steps (the value given to start_session).
            data: not used by this controller; presumably accepted to match
                the RateController interface.

        Returns:
            The learning rate for ``step``.

        Raises:
            RuntimeError: if start_session has not been called yet, or if
                ``step`` is outside the range [1, num_steps].
        """
        if self.learning_rates is None:
            raise RuntimeError("You must call 'start_session' before calling "
                               "'new_learning_rate'")
        num_steps = len(self.learning_rates)
        if step < 1:
            # Steps are 1-based; without this guard step=0 (or a negative
            # step) would silently wrap around to the end of the array.
            raise RuntimeError(f"Argument step={step}, must be at least 1")
        if step > num_steps:
            raise RuntimeError(f"Argument step={step}, should not "
                               f"exceed num_steps={num_steps}")
        return self.learning_rates[step - 1]