from sconce.schedules.base import ScheduledMixin


class ParameterGroup(ScheduledMixin):
    """
    A parameter group is the way that sconce models organize nn.Module parameters and
    their associated optimizers.

    Arguments:
        parameters (iterable of :py:class:`torch.nn.Parameter`): the parameters you
            want to group together.
        name (string): your name for this group.
        is_active (bool, optional): should this group be considered active (used
            during training)?
    """
    def __init__(self, parameters, name, is_active=True):
        super().__init__()
        self.parameters = parameters
        self.name = name
        self.optimizer = None
        self.is_active = is_active
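
    # A minimal usage sketch (not part of the original source): grouping a torch
    # module's parameters under a name.  The module ``net`` here is hypothetical.
    #
    #     import torch.nn as nn
    #     net = nn.Linear(10, 2)
    #     group = ParameterGroup(list(net.parameters()), name='classifier')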

    def set_optimizer(self, optimizer_class, *args, **kwargs):
        """
        Set an optimizer on this parameter group.  If this parameter group is active
        (has ``is_active=True``), then this optimizer will be used during training.

        Arguments:
            optimizer_class (one of the :py:mod:`torch.optim` classes): the class of
                optimizer to set.

        Note:
            All other arguments and keyword arguments are passed on to the
            ``optimizer_class`` constructor.
        """
        self.optimizer = optimizer_class(params=self.parameters, *args, **kwargs)
        return self.optimizer
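
    # A hedged example (not part of the original source): attaching a torch SGD
    # optimizer to a group.  Extra keyword arguments such as ``lr`` and
    # ``momentum`` are forwarded to the SGD constructor.
    #
    #     import torch.optim
    #     group.set_optimizer(torch.optim.SGD, lr=0.1, momentum=0.9)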

    def set_learning_rate(self, desired_learning_rate):
        """
        Set ``lr`` to ``desired_learning_rate`` in every param group of this group's
        optimizer, and return the value that was set.
        """
        param_groups = self.optimizer.param_groups
        for param_group in param_groups:
            param_group['lr'] = desired_learning_rate
        return desired_learning_rate

    def set_momentum(self, desired_momentum):
        """
        Set ``momentum`` to ``desired_momentum`` in every param group of this group's
        optimizer, and return the value that was set.
        """
        param_groups = self.optimizer.param_groups
        for param_group in param_groups:
            param_group['momentum'] = desired_momentum
        return desired_momentum

    def set_weight_decay(self, desired_weight_decay):
        """
        Set ``weight_decay`` to ``desired_weight_decay`` in every param group of this
        group's optimizer, and return the value that was set.
        """
        param_groups = self.optimizer.param_groups
        for param_group in param_groups:
            param_group['weight_decay'] = desired_weight_decay
        return desired_weight_decay
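
    # A hedged example (not part of the original source): adjusting optimizer
    # hyperparameters on an existing group.  Note that ``set_momentum`` assumes the
    # optimizer's param groups carry a ``momentum`` entry (true for SGD, not Adam).
    #
    #     group.set_learning_rate(0.01)
    #     group.set_momentum(0.95)
    #     group.set_weight_decay(1e-4)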

    def freeze(self):
        """
        Set ``requires_grad = False`` for all parameters in this group.
        """
        for parameter in self.parameters:
            parameter.requires_grad = False

    def unfreeze(self):
        """
        Set ``requires_grad = True`` for all parameters in this group.
        """
        for parameter in self.parameters:
            parameter.requires_grad = True
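

# A hedged sketch (not part of the original source) of how freeze/unfreeze might be
# used for fine-tuning: freeze one group so that only the other group's parameters
# receive gradients.  ``backbone`` and ``head`` are hypothetical nn.Module instances.
#
#     backbone_group = ParameterGroup(list(backbone.parameters()), name='backbone')
#     head_group = ParameterGroup(list(head.parameters()), name='head')
#     backbone_group.freeze()     # backbone weights stop receiving gradients
#     head_group.set_optimizer(torch.optim.Adam, lr=1e-3)
#     # ...later, unfreeze to fine-tune the whole model:
#     backbone_group.unfreeze()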