scNym/trainer Module

class scnym.trainer.Trainer(model, criterion, optimizer, dataloaders, out_path, batch_transformers={}, n_epochs=50, min_epochs=0, exp_name='', reg_criterion=None, use_gpu=False, verbose=False, save_freq=10, scheduler=None, tb_writer=None)

    Bases: object

    Trains a PyTorch model.

    model
        model with required .forward(...) method.
        Type: nn.Module

    criterion
        loss criterion to optimize.
        Type: Callable

    optimizer
        optimizer for the model parameters.
        Type: torch.optim.Optimizer

    dataloaders
        keyed by ['train', 'val'] with values corresponding to torch.utils.data.DataLoader for the training and validation sets.
        Type: dict

    out_path
        output path for the best model.
        Type: str

    n_epochs
        number of epochs for training.
        Type: int

    reg_criterion
        criterion to penalize layer weights.
        Type: Callable

    use_gpu
        use CUDA acceleration.
        Type: bool

    verbose
        write all batch losses to stdout.
        Type: bool

    save_freq
        number of epochs between model checkpoints. Default = 10.
        Type: int

    scheduler
        learning rate scheduler.

    train_epoch()
        Perform training across one full iteration through the data.

    val_epoch()
        Perform a pass through the validation data. Gradients are not recorded, to speed up evaluation.

    train()
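
A minimal construction sketch (not part of the scNym API reference): the toy model, synthetic data, and output path below are illustrative assumptions, and the exact batch format expected by Trainer may differ from the plain (input, target) tuples that a TensorDataset yields.

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from scnym.trainer import Trainer

    # toy data: 100 cells x 20 genes, 3 cell type classes (assumed for illustration)
    X = torch.randn(100, 20)
    y = torch.randint(0, 3, (100,))
    dataset = TensorDataset(X, y)

    model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 3))

    trainer = Trainer(
        model=model,
        criterion=nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
        dataloaders={
            "train": DataLoader(dataset, batch_size=16, shuffle=True),
            "val": DataLoader(dataset, batch_size=16),
        },
        out_path="./scnym_out",  # best model checkpoint is written here
        n_epochs=10,
    )
    trainer.train()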

class scnym.trainer.SemiSupervisedTrainer(unsup_criterion, unsup_dataloader, unsup_weight, dan_criterion=None, dan_weight=None, **kwargs)

    Bases: scnym.trainer.Trainer

    train_epoch()
        Perform training using both a supervised and a semi-supervised loss, following the steps sketched below.

        Notes

        1. Sample labeled examples and compute the standard supervised loss.
        2. Sample unlabeled examples and compute the unsupervised loss.
        3. Perform the backward pass and update parameters.

        Return type
            None
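
The per-batch logic described in the Notes can be sketched as follows. This is an illustrative toy, not the library's implementation: the consistency-style unsupervised criterion and the fixed unsup_weight are assumptions.

    import torch
    import torch.nn.functional as F

    def semi_supervised_step(
        model, optimizer, sup_criterion, x_lab, y_lab, x_unlab, unsup_weight
    ):
        # (1) supervised loss on a labeled batch
        sup_loss = sup_criterion(model(x_lab), y_lab)
        # (2) toy unsupervised consistency loss on an unlabeled batch:
        # predictions on a noised copy should match predictions on the clean copy
        with torch.no_grad():
            target = F.softmax(model(x_unlab), dim=1)
        noisy = F.softmax(model(x_unlab + 0.1 * torch.randn_like(x_unlab)), dim=1)
        unsup_loss = F.mse_loss(noisy, target)
        # (3) weighted combination, single backward pass, parameter update
        loss = sup_loss + unsup_weight * unsup_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return float(loss)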

scnym.trainer.get_class_weight(y)

    Generate relative class weights based on the representation of classes in a label vector y.

    Parameters
        y (np.ndarray) – [N,] vector of class labels.

    Returns
        class_weight – [Classes,] vector of loss weight coefficients. If classes are str, weights are returned in lexicographically sorted order.

    Return type
        np.ndarray
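
For example (the exact weighting scheme is not specified here; the documented behavior is that weights follow the lexicographically sorted class labels, and the intent is to give underrepresented classes more attention in the loss):

    import numpy as np
    from scnym.trainer import get_class_weight

    # an imbalanced label vector: 90 B cells, 10 T cells
    y = np.array(["B_cell"] * 90 + ["T_cell"] * 10)
    w = get_class_weight(y)
    # w has shape [2,], ordered lexicographically: index 0 -> "B_cell", index 1 -> "T_cell"
    print(w)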

scnym.trainer.cross_entropy(pred_, label, class_weight=None, sample_weight=None, reduction='mean')

    Compute the cross entropy loss for prediction outputs and potentially non-binary targets.

    Parameters
        pred_ (torch.FloatTensor) – [Batch, C] model outputs.
        label (torch.FloatTensor) – [Batch, C] labels. May not necessarily be one-hot, but must satisfy the simplex criterion.
        class_weight (torch.FloatTensor) – [C,] relative weights for each of the output classes. Useful for increasing attention to underrepresented classes.
        reduction (str) – reduction method across the batch.

    Returns
        loss – mean cross-entropy loss across the batch indices.

    Return type
        torch.FloatTensor

    Notes

    Cross entropy is defined as:

    \[H(P, Q) = -\sum_{k \in K} P(k) \log Q(k)\]

    where P, Q are discrete probability distributions defined over a common support K.

    References

    See for class weight computation: https://pytorch.org/docs/stable/nn.html#crossentropyloss
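
A small usage sketch with soft (non-one-hot) targets; it assumes pred_ holds unnormalized model outputs for a batch of two examples over three classes:

    import torch
    from scnym.trainer import cross_entropy

    # [Batch=2, C=3] model outputs
    pred = torch.tensor([[2.0, 0.5, 0.1],
                         [0.2, 1.5, 0.3]])
    # soft targets: each row sums to 1 (simplex criterion) but is not one-hot
    label = torch.tensor([[0.7, 0.2, 0.1],
                          [0.0, 0.9, 0.1]])
    loss = cross_entropy(pred, label)  # scalar, mean over the batch by default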

class scnym.trainer.InterpolationConsistencyLoss(unsup_criterion, sup_criterion, decay_coef=0.9997, mean_teacher=True, augment=None, teacher_eval=True, teacher_bn_running_stats=None, **kwargs)

    Bases: object

scnym.trainer.sharpen_labels(q, T=0.5)

    Reduce the entropy of a categorical label using a temperature adjustment.

    Parameters
        q (torch.FloatTensor) – [N, C] pseudolabel.
        T (float) – temperature parameter.

    Returns
        q_s – [C,] sharpened pseudolabel.

    Return type
        torch.FloatTensor

    Notes

    \[S(q, T)_i = q_i^{1/T} \Big/ \sum_{j=1}^{C} q_j^{1/T}\]
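
For instance, with a mildly peaked pseudolabel and the default temperature T=0.5, sharpening pushes mass toward the most confident class (values below are approximate):

    import torch
    from scnym.trainer import sharpen_labels

    q = torch.tensor([[0.40, 0.35, 0.25]])  # [N=1, C=3] soft pseudolabel
    q_s = sharpen_labels(q, T=0.5)
    # with T=0.5 each probability is squared and renormalized,
    # giving roughly [0.46, 0.36, 0.18]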

class scnym.trainer.MixMatchLoss(n_augmentations=2, T=0.5, augment_pseudolabels=True, pseudolabel_min_confidence=0.0, **kwargs)

    Bases: scnym.trainer.InterpolationConsistencyLoss

    Compute the MixMatch loss given a batch of labeled and unlabeled examples.

    n_augmentations
        number of augmented samples to average across when computing pseudolabels. Default = 2, from the MixMatch paper.
        Type: int

    T
        temperature parameter.
        Type: float

    augment_pseudolabels
        perform augmentations during pseudolabel generation.
        Type: bool

    pseudolabel_min_confidence
        minimum confidence required to compute a loss for a given pseudolabeled example. Examples below this confidence threshold receive a loss of 0. See the FixMatch paper for discussion.
        Type: float

    teacher
        teacher model for pseudolabeling.
        Type: nn.Module

    running_confidence_scores
        [n_batches_to_store,] list of (torch.Tensor, torch.Tensor) tuples holding (Confident_Bool, BestConfidenceScore) for unlabeled examples.
        Type: list

    n_batches_to_store
        determines how many batches to keep in running_confidence_scores.
        Type: int
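
A construction sketch: the extra keyword arguments are assumed to be forwarded to the InterpolationConsistencyLoss base class, and the choice of criteria here is purely illustrative (the call signature for applying the loss to labeled/unlabeled batches is not shown).

    import torch.nn as nn
    from scnym.trainer import MixMatchLoss, cross_entropy

    mixmatch = MixMatchLoss(
        n_augmentations=2,               # pseudolabels averaged over 2 augmentations
        T=0.5,                           # sharpening temperature
        augment_pseudolabels=True,
        pseudolabel_min_confidence=0.9,  # FixMatch-style confidence threshold
        # kwargs below are assumed to be consumed by InterpolationConsistencyLoss
        unsup_criterion=nn.MSELoss(),
        sup_criterion=cross_entropy,
    )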

class scnym.trainer.DANLoss(dan_criterion, model, use_conf_pseudolabels=False, scale_loss_pseudoconf=False)

    Bases: object

    Compute a domain adaptation network (DAN) loss.

class scnym.trainer.ICLWeight(ramp_epochs, burn_in_epochs=0, max_unsup_weight=10.0, sigmoid=False)

    Bases: object
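
The constructor parameters suggest a schedule for the unsupervised loss weight: hold it at zero for burn_in_epochs, then ramp it toward max_unsup_weight over ramp_epochs, optionally along a sigmoid rather than a linear curve. The sketch below illustrates that idea only; it is an assumption about the schedule, not the class's actual implementation or call signature.

    def ramp_weight(epoch, ramp_epochs, burn_in_epochs=0, max_unsup_weight=10.0):
        # assumed linear ramp: 0 during burn-in, then linear up to max_unsup_weight
        if epoch < burn_in_epochs:
            return 0.0
        progress = min((epoch - burn_in_epochs) / max(ramp_epochs, 1), 1.0)
        return progress * max_unsup_weight

    # e.g. with ramp_epochs=100 and burn_in_epochs=20:
    # epoch 10 -> 0.0, epoch 70 -> 5.0, epoch 150 -> 10.0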