Source code for sconce.models.base

from abc import ABC, abstractmethod
from sconce.parameter_group import ParameterGroup
from sconce.schedules.base import ScheduledMixin
from torch import nn

import numpy as np


class Model(ABC, nn.Module, ScheduledMixin):
    """
    The base class of all Models in Sconce.  It is only an interface,
    describing what must be implemented if you want to define a model.
    """
    def __init__(self):
        # Explicitly initialize each base class along the MRO.
        super(ABC, self).__init__()
        super(nn.Module, self).__init__()
        super(ScheduledMixin, self).__init__()

        self._parameter_groups = {}
        self.default_parameter_group_name = '__default__'

    def build_parameter_groups(self):
        """
        This can be overridden to build additional parameter groups, which
        is useful if you're doing different layer-wise optimization
        schedules.
        """
        parameters = self.get_trainable_parameters()
        group = ParameterGroup(parameters=parameters,
                name=self.default_parameter_group_name)
        self.add_parameter_group(group)

    def add_parameter_group(self, group, inactivate_default=True):
        """
        Add a new parameter group to this model.

        Arguments:
            group (:py:class:`~sconce.parameter_group.ParameterGroup`): the
                parameter group to add.
            inactivate_default (bool, optional): if ``True``, then the
                default parameter group will have ``is_active`` set to
                ``False``.
        """
        self._parameter_groups[group.name] = group

        if inactivate_default and group.name != self.default_parameter_group_name:
            self.default_parameter_group.is_active = False

    @abstractmethod
    def forward(self, *, inputs, targets, **kwargs):
        """
        This method must accept arbitrary keyword arguments.  The base
        :py:class:`~sconce.trainer.Trainer` class will pass ``inputs`` and
        ``targets``, but subclasses may modify that behavior to include
        other keyword arguments.

        It must return a dictionary.  The dictionary is expected to include
        at least the key ``outputs``, but may include any other keys you
        like.  The value of the key ``outputs`` is expected to be the
        :py:class:`torch.Tensor` output of the model, used for calculating
        the loss.
        """

    @abstractmethod
    def calculate_loss(self, *, inputs, outputs, targets, **kwargs):
        """
        This method must accept arbitrary keyword arguments.  The base
        :py:class:`~sconce.trainer.Trainer` class will pass ``inputs``,
        ``outputs``, and ``targets``, but subclasses may modify that
        behavior to include other keyword arguments.

        It must return a dictionary.  The dictionary is expected to include
        at least the key ``loss``, but may include any other keys you like.
        The value of the key ``loss`` is expected to be the
        :py:class:`torch.Tensor` output of the loss function, used to
        back-propagate the gradients used by the optimizer.
        """

    def calculate_metrics(self, *, inputs, outputs, targets, loss, **kwargs):
        """
        This method must accept arbitrary keyword arguments.  The base
        :py:class:`~sconce.trainer.Trainer` class will pass ``inputs``,
        ``outputs``, ``targets``, and ``loss``, but subclasses may modify
        that behavior to include other keyword arguments.

        It must return a dictionary.  No restrictions are made on the keys
        or values of this dictionary.
        """
        return {}

    def get_optimizers(self):
        """
        Returns a list of optimizers for the parameters of this model.
        """
        result = []
        for group in self.active_parameter_groups:
            if group.optimizer is not None:
                result.append(group.optimizer)

        if not result:
            raise RuntimeError("No active parameter groups with optimizers "
                    "found.  Did you add an optimizer with 'set_optimizer'?")
        return result

    def get_trainable_parameters(self):
        """
        The trainable parameters that the model has.
        """
        return list(filter(lambda p: p.requires_grad, self.parameters()))

    def get_num_trainable_parameters(self):
        """
        The number of trainable parameters that the model has.
        """
        return sum([np.prod(p.size()) for p in self.get_trainable_parameters()])

    def prepare_for_step(self, step, current_state):
        """
        First handles any hyperparameter schedules added to the model
        itself, then gathers and combines the results of calling
        ``prepare_for_step`` on all of the model's active parameter groups.
        """
        model_hyperparameters = super().prepare_for_step(step=step,
                current_state=current_state)
        hyperparameters = {'model': model_hyperparameters}

        for group in self.active_parameter_groups:
            group_hyperparameters = group.prepare_for_step(step=step,
                    current_state=current_state)
            hyperparameters[group.name] = group_hyperparameters

        return hyperparameters

    def set_schedule(self, name, schedule):
        """
        Set the schedule for a hyperparameter on this model.

        Arguments:
            name (string): the name of the hyperparameter you want to
                schedule.
            schedule (:py:class:`~sconce.schedules.base.Schedule`): the
                schedule for that hyperparameter.

        Note:
            Some name values are interpreted specially.  Setting ``name`` to
            'learning_rate', 'momentum', or 'weight_decay' will delegate to
            setting schedules on all active parameter groups instead of on
            the model.
        """
        if name in ('learning_rate', 'momentum', 'weight_decay'):
            for group in self.active_parameter_groups:
                group.set_schedule(name=name, schedule=schedule)
        else:
            super().set_schedule(name=name, schedule=schedule)

    @property
    def parameter_groups(self):
        """
        A list of all parameter groups, inactive as well as active.
        """
        if not self._parameter_groups:
            self.build_parameter_groups()
        return self._parameter_groups.values()

    def get_parameter_group(self, name):
        """
        Get a parameter group by name.
        """
        if not self._parameter_groups:
            self.build_parameter_groups()
        return self._parameter_groups[name]

    @property
    def default_parameter_group(self):
        """
        The default parameter group is created automatically and includes
        all of the trainable parameters for the model.
        """
        return self.get_parameter_group(self.default_parameter_group_name)

    @property
    def active_parameter_groups(self):
        """
        A list of all active parameter groups.
        """
        return [g for g in self.parameter_groups if g.is_active]

    def set_optimizer(self, *args, **kwargs):
        """
        Set the optimizer for all of the active parameter groups on this
        model.
        """
        for group in self.active_parameter_groups:
            group.set_optimizer(*args, **kwargs)

    def start_session(self, num_steps):
        """
        Called by the :py:class:`~sconce.trainer.Trainer` when a training
        session starts.

        Arguments:
            num_steps (int): the number of steps the trainer will take
                during this training session.
        """
        super().start_session(num_steps)
        for group in self.active_parameter_groups:
            group.start_session(num_steps)

    def print_schedule_summary(self):
        """
        Print a summary of the scheduled hyperparameters on this model and
        its parameter groups.
        """
        for name, schedule in self.schedules.items():
            print(f'model.{name}: {schedule}')

        for group in self.parameter_groups:
            for schedule_name, schedule in group.schedules.items():
                if not group.is_active:
                    print('(inactive) ', end='')
                print(f'{group.name}.{schedule_name}: {schedule}')
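

# A minimal usage sketch, assuming only the contracts documented above.
# TinyClassifier and its layer sizes are hypothetical, and passing an
# optimizer class plus keyword arguments to set_optimizer is an assumption
# (ParameterGroup.set_optimizer is defined elsewhere and not shown here).
import torch
from torch.nn import functional as F


class TinyClassifier(Model):
    """Hypothetical subclass used only to illustrate the interface."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, *, inputs, targets, **kwargs):
        # Must return a dict with at least the key 'outputs'.
        outputs = self.fc(inputs.view(inputs.size(0), -1))
        return {'outputs': outputs}

    def calculate_loss(self, *, inputs, outputs, targets, **kwargs):
        # Must return a dict with at least the key 'loss'.
        return {'loss': F.cross_entropy(outputs, targets)}


model = TinyClassifier()

# set_optimizer forwards its arguments to every active ParameterGroup.
model.set_optimizer(torch.optim.SGD, lr=0.1, momentum=0.9)

# 'learning_rate' is one of the specially interpreted names, so this call
# would delegate to all active parameter groups (some_schedule would be any
# sconce Schedule instance; none are defined in this file).
# model.set_schedule('learning_rate', some_schedule)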