scNym/trainer Module

class scnym.trainer.Trainer(model, criterion, optimizer, dataloaders, out_path, batch_transformers={}, n_epochs=50, min_epochs=0, exp_name='', reg_criterion=None, use_gpu=False, verbose=False, save_freq=10, scheduler=None, tb_writer=None)

Bases: object

Trains a PyTorch model.

model

model with required .forward(…) method.

Type

nn.Module

criterion

loss criterion to optimize.

Type

Callable

optimizer

optimizer for the model parameters.

Type

torch.optim.Optimizer

dataloaders

keyed by ['train', 'val'] with values corresponding to torch.utils.data.DataLoader objects for the training and validation sets.

Type

dict

out_path

output path for best model.

Type

str

n_epochs

number of epochs for training.

Type

int

reg_criterion

criterion to penalize layer weights.

Type

Callable

use_gpu

use CUDA acceleration.

Type

bool

verbose

write all batch losses to stdout.

Type

bool

save_freq

Number of epochs between model checkpoints. Default = 10.

Type

int

scheduler

learning rate scheduler.

train_epoch()

Perform training across one full iteration through the data.

val_epoch()

Perform a pass through the validation data without recording gradients, to speed up evaluation.

train()

Run the full training loop for n_epochs, saving checkpoints every save_freq epochs and keeping the best model observed on the validation set at out_path.
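
A minimal usage sketch follows. The toy dataset, model, and hyperparameters are illustrative assumptions, as is the convention that batches are dicts keyed by 'input' and 'output'; consult the scnym data loading utilities for the canonical batch format:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset
    from scnym.trainer import Trainer

    class ToyCells(Dataset):
        # Toy stand-in for a single cell dataset. We assume batches are
        # dicts with 'input' and 'output' keys; this is an assumption,
        # not a contract documented on this page.
        def __init__(self, n_cells=256, n_genes=100, n_classes=5):
            self.X = torch.randn(n_cells, n_genes)
            self.y = torch.randint(0, n_classes, (n_cells,))

        def __len__(self):
            return len(self.y)

        def __getitem__(self, idx):
            return {"input": self.X[idx], "output": self.y[idx]}

    dl = DataLoader(ToyCells(), batch_size=32, shuffle=True)
    model = nn.Sequential(nn.Linear(100, 64), nn.ReLU(), nn.Linear(64, 5))

    trainer = Trainer(
        model=model,
        criterion=nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
        dataloaders={"train": dl, "val": dl},  # a real run uses a held-out val set
        out_path="./scnym_out",
        n_epochs=10,
    )
    trainer.train()
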
class scnym.trainer.SemiSupervisedTrainer(unsup_criterion, unsup_dataloader, unsup_weight, dan_criterion=None, dan_weight=None, **kwargs)

Bases: scnym.trainer.Trainer

train_epoch()

Perform training using both a supervised and semi-supervised loss.

Notes

  1. Sample labeled examples, compute the standard supervised loss.

  2. Sample unlabeled examples, compute unsupervised loss.

  3. Perform backward pass and update parameters.

Return type

None
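
A schematic of the three numbered steps above, as a sketch rather than the library's exact implementation; argument names such as sup_batch, unsup_x, and unsup_weight are assumptions:

    def semi_supervised_step(model, optimizer, sup_criterion, unsup_criterion,
                             sup_batch, unsup_x, unsup_weight):
        # Illustrative sketch only; argument names are assumptions.
        optimizer.zero_grad()

        # 1. supervised loss on a labeled batch
        x, y = sup_batch
        sup_loss = sup_criterion(model(x), y)

        # 2. unsupervised loss on an unlabeled batch
        unsup_loss = unsup_criterion(model, unsup_x)

        # 3. joint backward pass and parameter update
        loss = sup_loss + unsup_weight * unsup_loss
        loss.backward()
        optimizer.step()
        return loss.item()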

scnym.trainer.get_class_weight(y)

Generate relative class weights based on the representation of classes in a label vector y.

Parameters

y (np.ndarray) – [N,] vector of class labels.

Returns

class_weight – [Classes,] vector of loss weight coefficients. If classes are str, weights are returned in lexicographically sorted order.

Return type

np.ndarray
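
One common construction for such weights is inverse class frequency; a minimal sketch under that assumption (scnym's exact normalization may differ):

    import numpy as np

    def inverse_frequency_weights(y: np.ndarray) -> np.ndarray:
        # Illustrative inverse-frequency weighting; the exact scnym
        # normalization may differ. np.unique sorts classes, so str labels
        # yield lexicographically ordered weights, matching the docs above.
        classes, counts = np.unique(y, return_counts=True)
        weights = counts.sum() / (len(classes) * counts)  # mean weight ~ 1
        return weights

    y = np.array(["B_cell"] * 90 + ["T_cell"] * 10)
    print(inverse_frequency_weights(y))  # rarer classes receive larger weights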

scnym.trainer.cross_entropy(pred_, label, class_weight=None, sample_weight=None, reduction='mean')

Compute cross entropy loss for prediction outputs and potentially non-binary targets.

Parameters
  • pred_ (torch.FloatTensor) – [Batch, C] model outputs.

  • label (torch.FloatTensor) – [Batch, C] labels. Labels need not be one-hot, but each row must lie on the probability simplex (non-negative entries that sum to 1).

  • class_weight (torch.FloatTensor) – [C,] relative weights for each of the output classes. Useful for increasing attention to underrepresented classes.

  • reduction (str) – reduction method across the batch.

Returns

loss – cross-entropy loss across the batch, reduced according to reduction.

Return type

torch.FloatTensor

Notes

Cross entropy is defined as:

\[H(P, Q) = -\sum_{k \in K} P(k) \log Q(k)\]

where P, Q are discrete probability distributions defined with a common support K.

References

See for class weight computation: https://pytorch.org/docs/stable/nn.html#crossentropyloss
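
A minimal sketch of the definition above with soft targets, assuming pred_ holds unnormalized logits; this is not scnym's exact implementation:

    import torch
    import torch.nn.functional as F

    def soft_cross_entropy(pred_, label, class_weight=None, reduction="mean"):
        # Sketch of H(P, Q) with simplex-valued targets; assumes pred_
        # are unnormalized logits. Not scnym's exact implementation.
        log_q = F.log_softmax(pred_, dim=1)        # log Q(k)
        loss = -(label * log_q)                    # -P(k) log Q(k), per class
        if class_weight is not None:
            loss = loss * class_weight.unsqueeze(0)
        loss = loss.sum(dim=1)                     # sum over classes k
        if reduction == "mean":
            return loss.mean()
        elif reduction == "sum":
            return loss.sum()
        return loss                                # reduction == "none"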

class scnym.trainer.InterpolationConsistencyLoss(unsup_criterion, sup_criterion, decay_coef=0.9997, mean_teacher=True, augment=None, teacher_eval=True, teacher_bn_running_stats=None, **kwargs)

Bases: object

Compute an interpolation consistency loss for semi-supervised learning, generating pseudolabels for unlabeled examples with an exponential moving average "mean teacher" of the trained model.

scnym.trainer.sharpen_labels(q, T=0.5)

Reduce the entropy of a categorical label using a temperature adjustment.

Parameters
  • q (torch.FloatTensor) – [N, C] pseudolabel.

  • T (float) – temperature parameter.

Returns

q_s – [N, C] sharpened pseudolabels.

Return type

torch.FloatTensor

Notes

\[S(q, T)_i = \frac{q_i^{1/T}}{\sum_{j=1}^{C} q_j^{1/T}}\]
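
A direct sketch of the formula:

    import torch

    def sharpen(q: torch.Tensor, T: float = 0.5) -> torch.Tensor:
        # Temperature sharpening per the formula above; q is an [N, C]
        # simplex-valued pseudolabel, and T < 1 reduces its entropy.
        q_pow = q ** (1.0 / T)
        return q_pow / q_pow.sum(dim=1, keepdim=True)

    q = torch.tensor([[0.4, 0.35, 0.25]])
    print(sharpen(q, T=0.5))  # mass concentrates on the largest entry
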
class scnym.trainer.MixMatchLoss(n_augmentations=2, T=0.5, augment_pseudolabels=True, pseudolabel_min_confidence=0.0, **kwargs)

Bases: scnym.trainer.InterpolationConsistencyLoss

Compute the MixMatch Loss given a batch of labeled and unlabeled examples.

n_augmentations

number of augmented samples to average across when computing pseudolabels. Default = 2, following the MixMatch paper.

Type

int

T

temperature parameter.

Type

float

augment_pseudolabels

perform augmentations during pseudolabel generation.

Type

bool

pseudolabel_min_confidence

minimum confidence required to compute a loss for a given pseudolabeled example. Examples below this confidence threshold receive 0 loss. See the FixMatch paper for discussion.

Type

float

teacher

teacher model for pseudolabeling.

Type

nn.Module

running_confidence_scores

[n_batches_to_store,] list of (torch.Tensor, torch.Tensor) tuples, storing (Confident_Bool, BestConfidenceScore) for each batch of unlabeled examples.

Type

list

n_batches_to_store

determines how many batches to keep in running_confidence_scores.

Type

int
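
A schematic of the MixMatch-style pseudolabeling these attributes describe, reusing the sharpen sketch above; augment and the teacher call signature are assumptions, not scnym's API:

    import torch

    def make_pseudolabels(teacher, augment, x_unsup, n_augmentations=2,
                          T=0.5, pseudolabel_min_confidence=0.0):
        # Schematic only: average teacher predictions over augmented views,
        # sharpen, then mask out low-confidence examples (FixMatch-style).
        with torch.no_grad():
            probs = [torch.softmax(teacher(augment(x_unsup)), dim=1)
                     for _ in range(n_augmentations)]
            q = torch.stack(probs).mean(dim=0)   # average over augmentations
        q = sharpen(q, T=T)                       # see sharpen_labels above
        confident = q.max(dim=1).values >= pseudolabel_min_confidence
        return q, confident                       # loss is zeroed where ~confident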

class scnym.trainer.DANLoss(dan_criterion, model, use_conf_pseudolabels=False, scale_loss_pseudoconf=False)

Bases: object

Compute a domain adaptation network (DAN) loss.
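
DAN-style losses commonly train a domain classifier on embeddings from source and target data through a gradient reversal layer (Ganin et al. 2016); a sketch under that reading, with assumed names rather than scnym's exact implementation:

    import torch
    import torch.nn.functional as F

    class GradReverse(torch.autograd.Function):
        # Gradient reversal: identity on the forward pass, negated
        # (scaled) gradient on the backward pass.
        @staticmethod
        def forward(ctx, x, scale):
            ctx.scale = scale
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            return -ctx.scale * grad_output, None

    def dan_loss(domain_clf, source_embed, target_embed, scale=1.0):
        # Sketch: classify the domain of each embedding; reversing the
        # gradient drives the feature extractor toward domain invariance.
        embeddings = torch.cat([source_embed, target_embed], dim=0)
        domains = torch.cat([
            torch.zeros(len(source_embed), dtype=torch.long),
            torch.ones(len(target_embed), dtype=torch.long),
        ])
        logits = domain_clf(GradReverse.apply(embeddings, scale))
        return F.cross_entropy(logits, domains)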

class scnym.trainer.ICLWeight(ramp_epochs, burn_in_epochs=0, max_unsup_weight=10.0, sigmoid=False)

Bases: object

Weight schedule for the interpolation consistency loss: the weight is 0 for burn_in_epochs, then ramps up to max_unsup_weight across ramp_epochs, either linearly or along a sigmoid curve.
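
A sketch of the schedule suggested by the constructor arguments (the exact ramp used by scnym may differ):

    import numpy as np

    def icl_weight(epoch, ramp_epochs, burn_in_epochs=0,
                   max_unsup_weight=10.0, sigmoid=False):
        # Inferred from the ICLWeight constructor arguments; the exact
        # schedule in scnym may differ. Weight is 0 during burn-in, then
        # ramps linearly (or with an exponential sigmoid) to the maximum.
        if epoch < burn_in_epochs:
            return 0.0
        t = min((epoch - burn_in_epochs) / max(ramp_epochs, 1), 1.0)
        if sigmoid:
            # exponential ramp common in mean teacher training schedules
            t = float(np.exp(-5.0 * (1.0 - t) ** 2))
        return max_unsup_weight * t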