Source code for thelper.optim.schedulers
"""Schedulers module.
This module contains classes used for scheduling learning rate changes while training a model. All
classes defined here should derive from ``torch.optim.lr_scheduler._LRScheduler`` to remain torch-
compatible.
"""
import bisect
import logging

import torch.optim

logger = logging.getLogger(__name__)


class CustomStepLR(torch.optim.lr_scheduler._LRScheduler):
"""Sets the learning rate of each parameter group using a dictionary of preset scaling factors
for epoch-based milestones.
This class can be useful for tuning the learning rate scheduling behavior of a training session
beyond what is already possible using PyTorch's existing LR scheduler classes. Note that all
epoch indices are assumed to be 0-based.
Usage example in Python::
# Assuming the optimizer uses lr = 0.05, we hard-code a slow startup...
# lr = 0.00625 if epoch < 2 (1/8 scale before epoch 2)
# lr = 0.0125 if 2 <= epoch < 3 (1/4 scale before epoch 3)
# lr = 0.025 if 3 <= epoch < 4 (1/2 scale before epoch 4)
# lr = 0.05 if 4 <= epoch < 30 (default scale between epoch 4 and 30)
# lr = 0.005 if 30 <= epoch < 80 (1/10 scale past epoch 30)
# lr = 0.0005 if epoch >= 80 (1/100 scale past epoch 80)
scheduler = CustomStepLR(optimizer, milestones={
0: 1/8,
2: 1/4,
3: 1/2,
4: 1,
30: 0.1,
80: 0.01
})
for epoch in range(100):
scheduler.step(epoch)
train(...)
validate(...)

    Usage example inside a session configuration file::

        # ...
        # lists the model optimization parameters for the training session
        "optimization": {
            # lists the optimizer arguments (type, parameters, LR, ...)
            "optimizer": {
                # ...
            },
            # lists the scheduler arguments (field can be omitted if no scheduler is needed)
            "scheduler": {
                # the type used to instantiate the scheduler
                "type": "thelper.optim.schedulers.CustomStepLR",
                # the parameters passed to the scheduler's constructor
                "params": {
                    # by default, the optimizer is passed automatically;
                    # we only need to specify the extra parameters here
                    "milestones": {
                        "1": 1,         # after epoch 1, scale the LR by 1
                        "10": 0.1,      # after epoch 10, scale the LR by 0.1
                        "20": 0.01,     # ... and so on
                        "30": 0.001,
                        "40": 0.0001
                    }
                }
            }
        },
        # ...

    Attributes:
        stages: list of epochs where a new scaling factor is to be applied.
        scales: list of scaling factors to apply at each stage.
        milestones: original milestones map provided in the constructor.

    .. seealso::
        | :class:`thelper.train.base.Trainer`
    """

    def __init__(self, optimizer, milestones, last_epoch=-1):
        """Receives the optimizer, milestone scaling factors, and initialization state.

        If the milestones do not include the first epoch (idx = 0), then its scaling factor is set
        to 1. When ``last_epoch`` is -1, the training is assumed to start from scratch.

        Args:
            optimizer: Wrapped optimizer (PyTorch-compatible object).
            milestones: Map of epoch indices tied to scaling factors. Keys must be increasing.
            last_epoch: The index of the last epoch. Default: -1.
        """
        if not isinstance(milestones, dict):
            raise AssertionError("milestones should be provided as a dictionary")
        self.stages = []
        if len(milestones) > 0:
            if isinstance(list(milestones.keys())[0], str):  # fixup for json-based config loaders
                self.stages = [int(key) for key in milestones.keys()]
            elif isinstance(list(milestones.keys())[0], int):
                self.stages = list(milestones.keys())
            else:
                raise AssertionError("milestone stages should be epoch indices (integers)")
            if self.stages != sorted(self.stages):
                raise AssertionError("milestone stages should be increasing integers")
            if not isinstance(list(milestones.values())[0], (float, int)):
                raise AssertionError("milestone scaling factors should be int/float")
        self.scales = [float(scale) for scale in milestones.values()]
        if 0 not in self.stages:
            # as documented above, the first epoch (idx = 0) defaults to a scaling factor of 1
            self.stages.insert(0, int(0))
            self.scales.insert(0, float(1))
        self.milestones = milestones
        super().__init__(optimizer, last_epoch)
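
    # Note on the stage lookup below: each epoch maps to the most recent milestone at or before it
    # via ``bisect_right``. For example (hypothetical values), with stages = [0, 10, 20] and
    # scales = [1.0, 0.1, 0.01], epochs 0-9 use scale 1.0, epochs 10-19 use 0.1, and any epoch
    # >= 20 uses 0.01; epochs below the first listed stage fall back to index 0.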
    def _get_stage_idx(self, epoch):
        if epoch in self.stages:
            return self.stages.index(epoch)
        return max(bisect.bisect_right(self.stages, epoch) - 1, 0)

    def get_lr(self):
        """Returns the learning rate to use given the current epoch and scaling factors."""
        scale = self.scales[self._get_stage_idx(self.last_epoch)]
        return [base_lr * scale for base_lr in self.base_lrs]
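

# A minimal usage sketch (not part of the framework's training loop; the names below are only for
# illustration): it relies only on ``torch``, wraps a dummy SGD optimizer around a single
# parameter, and prints the learning rate applied at each epoch so the milestone scaling can be
# checked by hand. String keys mimic what a json-based session configuration would provide.
if __name__ == "__main__":
    dummy_param = torch.zeros(1, requires_grad=True)
    dummy_optim = torch.optim.SGD([dummy_param], lr=0.05)
    dummy_sched = CustomStepLR(dummy_optim, milestones={"2": 0.5, "4": 0.1})
    for epoch in range(6):
        # expected printout: 0.05, 0.05, 0.025, 0.025, 0.005, 0.005
        print(epoch, dummy_optim.param_groups[0]["lr"])
        dummy_sched.step()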