---
title: Callbacks
keywords: fastai
sidebar: home_sidebar
summary: "API details."
description: "API details."
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class TorchCallback[source]

TorchCallback()

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class BasicConfig[source]

BasicConfig(order=0) :: TorchCallback

Handles basic model tasks like putting the model on the GPU
and switching between train and eval modes.
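For illustration, this is roughly what those two tasks look like in plain PyTorch (a conceptual sketch, not this callback's implementation; the model here is a placeholder):

```python
import torch
import torch.nn as nn

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(10, 2).to(device)   # placeholder model

model.train()   # training phase: dropout/batchnorm use training behavior
# ... forward/backward passes would run here ...
model.eval()    # evaluation phase: deterministic inference behavior
```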
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class StatsHandler[source]

StatsHandler(order=5) :: TorchCallback

This updates metrics at the end of each epoch to account for
potentially varying batch sizes.
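To see why batch sizes matter, here is a toy illustration of the weighting this describes (a sketch, not the library's code): a plain mean over batch losses over-weights a smaller final batch, while weighting by batch size recovers the true per-sample average.

```python
batch_losses = [0.90, 0.70, 0.65]
batch_sizes = [32, 32, 8]            # the final batch is smaller

# Naive mean over batches treats every batch equally.
naive_mean = sum(batch_losses) / len(batch_losses)

# Weighting each batch by its size gives the per-sample average.
weighted_mean = (sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
                 / sum(batch_sizes))

print(f'{naive_mean:.4f} vs. {weighted_mean:.4f}')
```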
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class MetricPrinter[source]

MetricPrinter(pbar_metric='loss', batch_freq=1, order=10) :: TorchCallback

Prints metrics at the end of each epoch. This is one of the
default callbacks provided by BaseModel; it does not need to
be passed in explicitly.
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class BatchMetricPrinter[source]

BatchMetricPrinter(batch_freq, n_prints=inf, order=10) :: TorchCallback

Prints mini-batch metrics to help us see whether a model is
learning early in training (helpful for debugging). We
remove the callback after the specified number of prints
so that it isn't called unnecessarily for the rest of
training.
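The behavior described above, print every `batch_freq` batches and stop after `n_prints` prints, looks roughly like this (a standalone sketch with made-up losses, not the callback's implementation):

```python
batch_freq, n_prints = 10, 3
printed = 0

for batch_idx in range(100):
    loss = 1.0 / (batch_idx + 1)                 # stand-in for a real batch loss
    if printed < n_prints and batch_idx % batch_freq == 0:
        print(f'batch {batch_idx}: loss={loss:.4f}')
        printed += 1
    # Once n_prints is reached nothing else happens; the real callback
    # removes itself at this point so it adds no overhead later in training.
```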
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class EarlyStopper[source]

EarlyStopper(metric, goal:('max', 'min'), min_improvement=0.0, patience=3, order=15) :: TorchCallback
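There is no docstring here, but the signature suggests the standard early-stopping rule: halt training once `metric` has failed to improve by at least `min_improvement` (toward `goal`) for `patience` epochs. A rough sketch of that rule under those assumptions (the library's exact bookkeeping may differ):

```python
def should_stop(history, goal='min', min_improvement=0.0, patience=3):
    """Return True if the last `patience` values never beat the earlier best
    by at least `min_improvement`. Sketch only, not the library's logic."""
    if len(history) <= patience:
        return False
    earlier, recent = history[:-patience], history[-patience:]
    if goal == 'min':
        best = min(earlier)
        return not any(m < best - min_improvement for m in recent)
    best = max(earlier)
    return not any(m > best + min_improvement for m in recent)

print(should_stop([0.90, 0.80, 0.79, 0.795, 0.80, 0.81], goal='min'))  # True
```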

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class PerformanceThreshold[source]

PerformanceThreshold(metric, goal:('min', 'max'), threshold, skip_epochs=0, split:('train', 'val')='val', order=15) :: TorchCallback

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class ModelCheckpoint[source]

ModelCheckpoint(metric='loss', goal:('max', 'min')='min', order=25) :: TorchCallback
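Again no docstring, but the signature points at the usual checkpointing pattern: save the model whenever the tracked metric improves. A plain-PyTorch sketch of that idea (the file name, fake metric values, and save logic here are assumptions, not the library's):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)             # placeholder model
best = float('inf')                  # goal='min': lower is better

for epoch, val_loss in enumerate([0.9, 0.7, 0.75, 0.6]):   # fake metric values
    if val_loss < best:
        best = val_loss
        # Persist the best weights seen so far.
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'epoch {epoch}: new best {best:.3f}, checkpoint saved')
```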

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class MetricHistory[source]

MetricHistory(fname='history.csv', plot_fname='history.png', order=90) :: TorchCallback

Separate from StatsHandler in case we don't want to log outputs.
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class S3Uploader[source]

S3Uploader(bucket, prefix, order=95) :: TorchCallback

Upload model and logs to S3 when training finishes.
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class BotoS3Uploader[source]

BotoS3Uploader(bucket, s3_dir='', retain_tree=True, recurse=True, keep_fn=None, order=95) :: TorchCallback

Upload model and logs to S3 when training finishes. This version of the
callback does not rely on any GoGuardian packages.
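As a rough illustration of the upload step, pushing an output directory to S3 with plain boto3 looks like this (a sketch with hypothetical paths and bucket names; the callback's actual key layout and its `retain_tree`/`keep_fn` filtering are not shown):

```python
from pathlib import Path
import boto3

s3 = boto3.client('s3')
out_dir = Path('outputs')            # hypothetical local directory of model + logs
bucket = 'my-training-bucket'        # hypothetical bucket
s3_dir = 'experiments/run-001'       # analogous to the s3_dir argument

for path in out_dir.rglob('*'):      # recurse=True style traversal
    if path.is_file():
        key = f'{s3_dir}/{path.relative_to(out_dir)}'
        s3.upload_file(str(path), bucket, key)
```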
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class EC2Closer[source]

EC2Closer(timeout=5, order=100) :: TorchCallback

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class ModelUnfreezer[source]

ModelUnfreezer(i2n, unfreeze_type:('groups', 'layers')='groups', mode:('batch', 'epoch')='epoch', order=25) :: TorchCallback

Gradually unfreeze a model during training.
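Conceptually, gradual unfreezing flips `requires_grad` back on for successive parameter groups as training progresses. A minimal sketch of that idea (the model, grouping, and epoch-to-group schedule below are purely illustrative and only loosely mirror the `i2n` mapping):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 2))
groups = list(model.children())

# Start with everything frozen except the last group (e.g. a new head).
for layer in groups[:-1]:
    for p in layer.parameters():
        p.requires_grad = False

# Hypothetical schedule: at epoch 2 unfreeze one more group, at epoch 4 another.
epoch_to_n_unfrozen = {2: 2, 4: 3}

def unfreeze(epoch):
    n = epoch_to_n_unfrozen.get(epoch)
    if n is not None:
        for layer in groups[-n:]:
            for p in layer.parameters():
                p.requires_grad = True
```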
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class SchedulerMixin[source]

SchedulerMixin() :: TorchCallback

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class CosineLRScheduler[source]

CosineLRScheduler(warm=0.3, restarts=False, cycle_len=5, cycle_decay=0.0, min_lr=None, verbose=False, order=10) :: SchedulerMixin

Learning rate scheduler that makes updates each batch.
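The per-batch schedule is the familiar warmup-plus-cosine-annealing curve. A sketch of one cycle of that curve (interpreting `warm` as the fraction of the cycle spent in linear warmup, which is an assumption rather than documented behavior):

```python
import math

def cosine_lr(step, total_steps, max_lr=1e-3, min_lr=1e-5, warm=0.3):
    """Linear warmup for the first `warm` fraction of steps, then cosine decay."""
    warm_steps = int(total_steps * warm)
    if step < warm_steps:
        return min_lr + (max_lr - min_lr) * step / max(warm_steps, 1)
    progress = (step - warm_steps) / max(total_steps - warm_steps, 1)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

lrs = [cosine_lr(s, 100) for s in range(100)]   # one cycle of per-batch LRs
```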
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class AdaptiveSawtoothScheduler[source]

AdaptiveSawtoothScheduler(add=0.0001, scale=0.6, patience=5, order=10) :: SchedulerMixin

Learning rate scheduler inspired by the sawtooth pattern often
used to manage TCP flow
(e.g., https://witestlab.poly.edu/blog/tcp-congestion-control-basics/).
This uses a strategy called "additive increase, multiplicative decrease".
Basically, while the training loss is generally decreasing, we
gradually increase the learning rate. When things show signs of getting
worse, we dramatically decrease the LR and begin slowly climbing again.
The result looks something like a cyclical policy with restarts,
except that in this case the cycle lengths are dependent on training
rather than pre-defined. SGD with restarts typically also uses a sharp
increase and a gradual decrease, while this is closer to the opposite.

Unlike the standard AIMD algorithm, we decay the amount added if the
batch loss increases, even if we're still within the patience window.
{% endraw %}

Here is a sample training run with an adaptive sawtooth scheduler.

{% include image.html file="/incendio/adaptive_sawtooth_lrs.png" %}
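A stripped-down sketch of the additive-increase / multiplicative-decrease loop described above (how `scale` is applied to the increment versus the LR, and the exact patience bookkeeping, are assumptions here rather than the library's documented behavior):

```python
def aimd_update(lr, add, loss, best_loss, bad_count, scale=0.6, patience=5):
    """One per-batch update. Returns the new (lr, add, best_loss, bad_count)."""
    if loss < best_loss:
        # Loss is still improving: additive increase, reset the patience counter.
        return lr + add, add, loss, 0
    # Batch loss got worse: decay the increment immediately (see note above),
    # even though we may still be inside the patience window.
    add *= scale
    bad_count += 1
    if bad_count >= patience:
        # Signs of trouble: multiplicative decrease, then start climbing again.
        return lr * scale, add, loss, 0
    return lr, add, best_loss, bad_count

# Example usage with made-up batch losses:
# lr, add, best, bad = 1e-3, 1e-4, float('inf'), 0
# lr, add, best, bad = aimd_update(lr, add, batch_loss, best, bad)
```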