Source code for sconce.rate_controllers.constant_rate_controller

from sconce.rate_controllers.base import RateController

from sconce.monitors.ringbuffer_monitor import RingbufferMonitor


[docs]class ConstantRateController(RateController): """ A Learning rate that is constant. It can adjust its learning rate by <drop_factor> up to <num_drops> times based on detecting that some metric or loss has stopped moving. """ def __init__(self, learning_rate, drop_factor=0.1, movement_key='training_loss', movement_threshold=0.25, movement_window=None, num_drops=0): self.learning_rate = learning_rate self.movement_window = movement_window self.movement_key = movement_key self.num_drops = num_drops self.drop_factor = drop_factor self.movement_threshold = movement_threshold self.monitor = None self.num_drops_taken = 0 self.factor = 1.0
[docs] def start_session(self, *args): if self.movement_window is not None: self.reset_monitor()
[docs] def reset_monitor(self): self.monitor = RingbufferMonitor(capacity=2 * self.movement_window, key=self.movement_key)
[docs] def new_learning_rate(self, step, data): if self.monitor is not None: self.monitor.write(data=data, step=step) movement_index = self.monitor.movement_index if movement_index is not None: if movement_index < self.movement_threshold: if self.num_drops_taken < self.num_drops: self.reset_monitor() self.num_drops_taken += 1 self.factor *= self.drop_factor else: return None return self.factor * self.learning_rate