Module Documentation¶
Import scNym as:
import scnym
Interactive API: api¶
scNym provides a simple Python API that serves as the primary endpoint for users. This API should be the first stop for users looking to apply scNym to new problems.
Classify cell identities using scNym
scnym_api() is the main API endpoint for users. It dispatches to scnym_train() and scnym_predict() for training and prediction; most users will rarely need to call either of those functions directly.
get_pretrained_weights() is a wrapper function that downloads pretrained weights from our cloud storage bucket. atlas2target() downloads preprocessed reference datasets and concatenates them onto a user-supplied target dataset.
scnym.api.scnym_api(adata, task='train', groupby=None, out_path='./scnym_outputs', trained_model=None, config='new_identity_discovery', key_added='scNym', copy=False)¶
scNym: Semi-supervised adversarial neural networks for single cell classification [Kimmel2020].
scNym is a cell identity classifier that transfers annotations from one single cell experiment to another. The model is implemented as a neural network that employs MixMatch semi-supervision and a domain adversary to take advantage of unlabeled data during training. scNym offers superior performance to many baseline single cell identity classification methods.
- Parameters
adata (AnnData) – Annotated data matrix used for training or prediction.
task (str) – Task to perform, either "train" or "predict". If "train", uses adata as labeled training data. If "predict", uses trained_model to infer cell identities for observations in adata.
groupby (Optional[str]) – Column in adata.obs that contains cell identity annotations. Values of "Unlabeled" indicate that a given cell should be used only as unlabeled data during training.
out_path (str) – Path to a directory for saving scNym model weights and training logs.
trained_model (Optional[str]) – Used when task=="predict". Path to the output directory of an scNym training run, or a string specifying a pretrained model. Pretrained model strings are f"pretrained_{species}" where species is one of {"human", "mouse", "rat"}. Providing a pretrained model string will download pre-trained weights and predict directly on the target data, without additional training.
config (Union[dict, str]) – Configuration name or dictionary of configuration parameters. Pre-defined configurations:
"new_identity_discovery" – Default. Employs pseudolabel thresholding to allow for discovery of new cell identities in the target dataset using scNym confidence scores.
"no_new_identity" – Assumes all cells in the target data belong to one of the classes in the training data. Recommended to improve performance when this assumption is valid.
key_added (str) – Key added to adata.obs with scNym predictions if task=="predict".
copy (bool) – copy the AnnData object before predicting cell types.
- Return type
Optional[AnnData]
- Returns
Depending on copy, returns or updates adata with the following fields:
X_scnym (adata.obsm, ndarray, shape=(n_samples, n_hidden), dtype float) – scNym embedding coordinates of data.
scNym (adata.obs, dtype str) – scNym cell identity predictions for each observation.
scNym_train_results (adata.uns, dict) – results of scNym model training.
Examples
>>> import scanpy as sc
>>> import numpy as np
>>> from scnym.api import scnym_api, atlas2target
Loading Data and preparing labels
>>> adata = sc.datasets.kang17()
>>> target_bidx = adata.obs['stim'] == 'stim'
>>> adata.obs['cell'] = np.array(adata.obs['cell'])
>>> adata.obs.loc[target_bidx, 'cell'] = 'Unlabeled'
Train an scNym model
>>> scnym_api(
...     adata=adata,
...     task='train',
...     groupby='cell',
...     out_path='./scnym_outputs',
...     config='no_new_identity',
... )
Predict cell identities with the trained scNym model
>>> path_to_model = './scnym_outputs/'
>>> scnym_api(
...     adata=adata,
...     task='predict',
...     groupby='scNym',
...     trained_model=path_to_model,
...     config='no_new_identity',
... )
Predict cell identities with a pretrained scNym model
>>> scnym_api(
...     adata=adata,
...     task='predict',
...     groupby='scNym',
...     trained_model='pretrained_human',
...     config='no_new_identity',
... )
Perform semi-supervised training with an atlas
>>> joint_adata = atlas2target(
...     adata=adata,
...     species='human',
...     key_added='annotations',
... )
>>> scnym_api(
...     adata=joint_adata,
...     task='train',
...     groupby='annotations',
...     out_path='./scnym_outputs',
...     config='no_new_identity',
... )
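Read back the outputs of a predict run
After a task='predict' call, the fields documented above can be read directly from adata. A minimal sketch; the neighbors/UMAP steps use scanpy's standard API and are an optional illustration:
>>> adata.obs['scNym'].head()                # predicted cell identities
>>> embedding = adata.obsm['X_scnym']        # (n_samples, n_hidden) embedding
>>> sc.pp.neighbors(adata, use_rep='X_scnym')
>>> sc.tl.umap(adata)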
scnym.api.scnym_train(adata, config)¶
Train an scNym model.
- Parameters
adata (AnnData) – [Cells, Genes] experiment containing annotated cells to train on.
config (dict) – configuration options.
- Return type
None
- Returns
None. Saves model outputs to config["out_path"] and adds model results to adata.uns["scnym_train_results"].
Notes
This method should only be directly called by advanced users. Most users should use scnym_api.
See also
scnym_api()

scnym.api.scnym_predict(adata, config)¶
Predict cell identities using an scNym model.
- Parameters
adata (AnnData) – [Cells, Genes] experiment containing cells to classify.
config (dict) – configuration options.
- Returns
None. Adds adata.obs[config["key_added"]] and adata.obsm["X_scnym"].
- Return type
None
Notes
This method should only be directly called by advanced users. Most users should use scnym_api.
See also
scnym_api()
scnym.api.get_pretrained_weights(trained_model, out_path)¶
Given the name of a set of pretrained model weights, fetch weights from GCS and return the model state dict.
- Parameters
trained_model (str) – the name of a pretrained model to use, formatted as “pretrained_{species}”. species should be one of {“human”, “mouse”, “rat”}.
out_path (str) – path for saving model weights and outputs.
- Return type
str
- Returns
species (str) – species parsed from the trained model name. Saves "{out_path}/00_best_model_weights.pkl" and "{out_path}/scnym_train_results.pkl".
Notes
Requires an internet connection to download pre-trained weights.
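A minimal usage sketch based only on the documented signature and outputs:
>>> from scnym.api import get_pretrained_weights
>>> species = get_pretrained_weights(
...     trained_model='pretrained_mouse',
...     out_path='./pretrained_model',
... )
>>> species
'mouse'
>>> # './pretrained_model/00_best_model_weights.pkl' is now on disk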
scnym.api.atlas2target(adata, species, key_added='annotations')¶
Download a preprocessed cell atlas dataset and append your new dataset as a target to allow for semi-supervised scNym training.
- Parameters
adata (anndata.AnnData) – [Cells, Features] experiment to use as a target dataset.
species (str) – species of the cell atlas to download and use as labeled training data.
key_added (str) – key in .obs where atlas annotations are stored; all target cells are labeled "Unlabeled".
- Returns
joint_adata – [Cells, Features] experiment concatenated with a preprocessed cell atlas reference dataset. Annotations from the atlas are copied to .obs[key_added] and all cells in the target dataset adata are labeled with the special “Unlabeled” token.
- Return type
anndata.AnnData
Examples
>>> import scanpy as sc
>>> adata = sc.datasets.pbmc3k()
>>> joint_adata = scnym.api.atlas2target(
...     adata=adata,
...     species='human',
...     key_added='annotations',
... )
Notes
Requires an internet connection to download reference datasets.
Advanced Interface¶
For users interested in exploring new research ideas with the scNym framework, we provide direct access to our underlying infrastructure. The modules below should be used by researchers looking to expand on the scNym approach.
Model Specification¶
class scnym.model.ResBlock(n_inputs, n_hidden)¶
Bases: torch.nn.modules.module.Module
Residual block.
References
Deep Residual Learning for Image Recognition. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. arXiv:1512.03385.
forward(x)¶
Residual block forward pass.
- Parameters
x (torch.FloatTensor) – [Batch, self.n_inputs]
- Returns
o – [Batch, self.n_hidden]
- Return type
torch.FloatTensor
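A shape-checking sketch based only on the documented signature, assuming ResBlock follows the standard nn.Module calling convention:
>>> import torch
>>> from scnym.model import ResBlock
>>> block = ResBlock(n_inputs=256, n_hidden=256)
>>> o = block(torch.randn(32, 256))   # [Batch, n_hidden] == (32, 256)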
class scnym.model.CellTypeCLF(n_genes, n_cell_types, n_hidden=256, n_layers=2, init_dropout=0.0, residual=False, batch_norm=True, track_running_stats=True)¶
Bases: torch.nn.modules.module.Module
Cell type classifier from expression data.
- n_genes (int) – number of input genes in the model.
- n_cell_types (int) – number of output classes in the model.
- n_hidden (int) – number of hidden units in the model.
- n_layers (int) – number of hidden layers in the model.
- init_dropout (float) – dropout proportion prior to the first layer.
- residual (bool) – use residual connections.
forward(x)¶
Perform a forward pass through the model.
- Parameters
x (torch.FloatTensor) – [Batch, self.n_genes]
- Returns
pred – [Batch, self.n_cell_types]
- Return type
torch.FloatTensor
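A minimal forward-pass sketch using the documented constructor defaults; treating the outputs as unnormalized logits is an assumption:
>>> import torch
>>> from scnym.model import CellTypeCLF
>>> model = CellTypeCLF(n_genes=784, n_cell_types=10)
>>> pred = model(torch.randn(8, 784))    # [8, 10] class scores
>>> probs = torch.softmax(pred, dim=1)   # assumes pred holds logits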
class scnym.model.GradReverse¶
Bases: torch.autograd.function.Function
Layer that reverses and scales gradients before passing them up to earlier ops in the computation graph during backpropagation.
static forward(ctx, x, weight)¶
Perform a no-op forward pass that stores a weight for later gradient scaling during backprop.
- Parameters
x (torch.FloatTensor) – [Batch, Features]
weight (float) – weight for scaling gradients during backpropagation. Stored in the "context" ctx variable.
Notes
We subclass Function and use only @staticmethod as specified in the new-style PyTorch autograd functions. https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
We define a "context" ctx of the class that will hold any values passed during forward for use in the backward pass.
x.view_as(x) and *1 are necessary so that GradReverse is actually called: torch.autograd tries to optimize backprop and excludes no-ops, so we have to trick it :)
static backward(ctx, grad_output)¶
Return reversed, scaled gradients.
- Returns
rev_grad (torch.FloatTensor) – reversed gradients scaled by weight passed in .forward()
None (None) – a dummy “gradient” required since we passed a weight float in .forward().
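A sketch checking the documented behavior through PyTorch's standard Function.apply interface:
>>> import torch
>>> from scnym.model import GradReverse
>>> x = torch.randn(4, 8, requires_grad=True)
>>> y = GradReverse.apply(x, 1.0)   # forward is a no-op
>>> y.sum().backward()
>>> x.grad                           # gradients of sum() are reversed: all -1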
class scnym.model.DANN(model, n_domains=2, weight=1.0, n_layers=1)¶
Bases: torch.nn.modules.module.Module
Build a domain adaptation neural network.
set_rev_grad_weight(weight)¶
Set the weight term used after reversing gradients.
- Return type
None
forward(x)¶
Perform a forward pass.
- Parameters
x (torch.FloatTensor) – [Batch, Features] input.
- Returns
domain_pred (torch.FloatTensor) – [Batch, n_domains] logits.
x_embed (torch.FloatTensor) – [Batch, n_hidden]
- Return type
(torch.FloatTensor, torch.FloatTensor)
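A construction sketch from the documented signature, wrapping a CellTypeCLF as the feature extractor:
>>> import torch
>>> from scnym.model import CellTypeCLF, DANN
>>> clf = CellTypeCLF(n_genes=784, n_cell_types=10)
>>> dann = DANN(model=clf, n_domains=2, weight=1.0)
>>> domain_pred, x_embed = dann(torch.randn(8, 784))
>>> domain_pred.shape   # (8, 2) domain logits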
class scnym.model.CellTypeCLFConditional(n_genes, n_tissues, **kwargs)¶
Bases: scnym.model.CellTypeCLF
Conditional variation of the CellTypeCLF.
- n_genes (int) – number of the input features corresponding to genes.
- n_tissues (int) – length of the one-hot upper_group vector appended to inputs.
forward(x)¶
Perform a forward pass through the model.
- Parameters
x (torch.FloatTensor) – [Batch, self.n_genes + self.n_tissues]
- Returns
pred – [Batch, self.n_cell_types]
- Return type
torch.FloatTensor
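A sketch of building the conditional input by appending a one-hot tissue vector; passing n_cell_types through **kwargs to the parent class is an assumption:
>>> import torch
>>> from scnym.model import CellTypeCLFConditional
>>> model = CellTypeCLFConditional(n_genes=784, n_tissues=4, n_cell_types=10)
>>> x = torch.randn(8, 784)
>>> tissue = torch.nn.functional.one_hot(
...     torch.zeros(8, dtype=torch.long), num_classes=4,
... ).float()
>>> pred = model(torch.cat([x, tissue], dim=1))   # [8, 10]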
Model Trainer¶
The trainer module provides classes for training neural network models to classify single cells.
class scnym.trainer.Trainer(model, criterion, optimizer, dataloaders, out_path, batch_transformers={}, n_epochs=50, min_epochs=0, exp_name='', reg_criterion=None, use_gpu=False, verbose=False, save_freq=10, scheduler=None, tb_writer=None)¶
Bases: object
Trains a PyTorch model.
- model (nn.Module) – model with required .forward(…) method.
- criterion (Callable) – loss criterion to optimize.
- optimizer (torch.optim.Optimizer) – optimizer for the model parameters.
- dataloaders (dict) – keyed by ['train', 'val'] with values corresponding to torch.utils.data.DataLoader for training and validation sets.
- out_path (str) – output path for best model.
- n_epochs (int) – number of epochs for training.
- reg_criterion (Callable) – criterion to penalize layer weights.
- use_gpu (bool) – use CUDA acceleration.
- verbose (bool) – write all batch losses to stdout.
- save_freq (int) – number of epochs between model checkpoints. Default = 10.
- scheduler – learning rate scheduler.
train_epoch()¶
Perform training across one full iteration through the data.
val_epoch()¶
Perform a pass through the validation data. Do not record gradients to speed things up.
train()¶
Run the full training loop and save the best model weights to out_path.
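A construction sketch assembled from the documented attributes; the batch format Trainer expects is an assumption (here, SingleCellDS sample dicts, see the Data Loaders section below):
>>> import numpy as np
>>> import torch
>>> from torch.utils.data import DataLoader
>>> from scnym.dataprep import SingleCellDS
>>> from scnym.model import CellTypeCLF
>>> from scnym.trainer import Trainer, cross_entropy
>>> X = np.random.poisson(1.0, (256, 784)).astype(np.float32)
>>> y = np.random.randint(0, 10, 256)
>>> dl = {k: DataLoader(SingleCellDS(X, y), batch_size=32) for k in ('train', 'val')}
>>> model = CellTypeCLF(n_genes=784, n_cell_types=10)
>>> trainer = Trainer(
...     model=model,
...     criterion=cross_entropy,
...     optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
...     dataloaders=dl,
...     out_path='./trainer_outputs',
...     n_epochs=5,
... )
>>> trainer.train()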
class scnym.trainer.SemiSupervisedTrainer(unsup_criterion, unsup_dataloader, unsup_weight, dan_criterion=None, dan_weight=None, **kwargs)¶
Bases: scnym.trainer.Trainer
train_epoch()¶
Perform training using both a supervised and semi-supervised loss.
Notes
1. Sample labeled examples and compute the standard supervised loss.
2. Sample unlabeled examples and compute the unsupervised loss.
3. Perform a backward pass and update parameters.
- Return type
None
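A self-contained schematic of the step sequence above; the consistency penalty here is illustrative only and is not scNym's actual unsupervised criterion:
>>> import torch
>>> import torch.nn.functional as F
>>> model = torch.nn.Linear(8, 3)
>>> opt = torch.optim.SGD(model.parameters(), lr=0.1)
>>> x_lab, y_lab = torch.randn(4, 8), torch.randint(0, 3, (4,))
>>> x_unlab = torch.randn(16, 8)
>>> sup_loss = F.cross_entropy(model(x_lab), y_lab)   # step 1: supervised loss
>>> unsup_loss = F.mse_loss(                          # step 2: unsupervised loss
...     torch.softmax(model(x_unlab), dim=1),
...     torch.softmax(model(x_unlab + 0.01 * torch.randn_like(x_unlab)), dim=1),
... )
>>> loss = sup_loss + 1.0 * unsup_loss   # unsup_weight = 1.0
>>> loss.backward()                       # step 3: backward pass and update
>>> opt.step()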
scnym.trainer.get_class_weight(y)¶
Generate relative class weights based on the representation of classes in a label vector y.
- Parameters
y (np.ndarray) – [N,] vector of class labels.
- Returns
class_weight – [Classes,] vector of loss weight coefficients. If classes are str, returns weights in lexicographically sorted order.
- Return type
np.ndarray
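A sketch of the standard inverse-frequency scheme this kind of function implements; the exact normalization used inside get_class_weight is not specified here:
>>> import numpy as np
>>> y = np.array(['B', 'B', 'B', 'T'])
>>> classes, counts = np.unique(y, return_counts=True)   # lexicographic order
>>> weights = counts.sum() / (len(classes) * counts)
>>> weights   # array([0.6667, 2.]) – minority class upweighted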
scnym.trainer.cross_entropy(pred_, label, class_weight=None, sample_weight=None, reduction='mean')¶
Compute cross entropy loss for prediction outputs and potentially non-binary targets.
- Parameters
pred_ (torch.FloatTensor) – [Batch, C] model outputs.
label (torch.FloatTensor) – [Batch, C] labels. may not necessarily be one-hot, but must satisfy simplex criterion.
class_weight (torch.FloatTensor) – [C,] relative weights for each of the output classes. useful for increasing attention to underrepresented classes.
reduction (str) – reduction method across the batch.
- Returns
loss – mean cross-entropy loss across the batch indices.
- Return type
torch.FloatTensor
Notes
Cross entropy is defined as:
\[H(P, Q) = -\sum_{k \in K} P(k) \log Q(k)\]
where P and Q are discrete probability distributions defined on a common support K.
References
See for class weight computation: https://pytorch.org/docs/stable/nn.html#crossentropyloss
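A minimal sketch of the cross entropy above with soft (non-one-hot) targets; assuming pred_ holds unnormalized scores, which may differ from the function's actual convention:
>>> import torch
>>> pred = torch.randn(4, 3)                   # [Batch, C] model outputs
>>> label = torch.full((4, 3), 1.0 / 3)        # rows on the simplex
>>> log_q = torch.log_softmax(pred, dim=1)
>>> loss = -(label * log_q).sum(dim=1).mean()  # H(P, Q) with 'mean' reduction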
class scnym.trainer.InterpolationConsistencyLoss(unsup_criterion, sup_criterion, decay_coef=0.9997, mean_teacher=True, augment=None, teacher_eval=True, teacher_bn_running_stats=None, **kwargs)¶
Bases: object
scnym.trainer.sharpen_labels(q, T=0.5)¶
Reduce the entropy of a categorical label using a temperature adjustment.
- Parameters
q (torch.FloatTensor) – [N, C] pseudolabel.
T (float) – temperature parameter.
- Returns
q_s – [N, C] sharpened pseudolabel.
- Return type
torch.FloatTensor
Notes
\[S(q, T)_i = q_i^{1/T} \Big/ \sum_{j=1}^{C} q_j^{1/T}\]
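A worked sketch of the sharpening formula:
>>> import torch
>>> q = torch.tensor([[0.6, 0.3, 0.1]])
>>> T = 0.5
>>> q_s = q ** (1.0 / T)
>>> q_s = q_s / q_s.sum(dim=1, keepdim=True)
>>> q_s   # tensor([[0.7826, 0.1957, 0.0217]]) – lower entropy than q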
class scnym.trainer.MixMatchLoss(n_augmentations=2, T=0.5, augment_pseudolabels=True, pseudolabel_min_confidence=0.0, **kwargs)¶
Bases: scnym.trainer.InterpolationConsistencyLoss
Compute the MixMatch Loss given a batch of labeled and unlabeled examples.
- n_augmentations (int) – number of augmented samples to average across when computing pseudolabels. Default = 2, from the MixMatch paper.
- T (float) – temperature parameter.
- augment_pseudolabels (bool) – perform augmentations during pseudolabel generation.
- pseudolabel_min_confidence (float) – minimum confidence to compute a loss for a given pseudolabeled example. Examples below this confidence threshold are given 0 loss. See the FixMatch paper for discussion.
- teacher (nn.Module) – teacher model for pseudolabeling.
- running_confidence_scores (list) – up to n_batches_to_store recent batches of unlabeled example (Confident_Bool, BestConfidenceScore) tuples.
- n_batches_to_store (int) – determines how many batches to keep in running_confidence_scores.
class scnym.trainer.DANLoss(dan_criterion, model, use_conf_pseudolabels=False, scale_loss_pseudoconf=False)¶
Bases: object
Compute a domain adaptation network (DAN) loss.
class scnym.trainer.ICLWeight(ramp_epochs, burn_in_epochs=0, max_unsup_weight=10.0, sigmoid=False)¶
Bases: object
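The constructor arguments suggest a warm-up schedule for the unsupervised loss weight. A sketch of a typical linear ramp consistent with these parameter names; the actual schedule inside ICLWeight may differ:
>>> def linear_ramp(epoch, ramp_epochs=100, burn_in_epochs=0, max_unsup_weight=10.0):
...     # zero weight during burn-in, then ramp linearly to the maximum
...     if epoch < burn_in_epochs:
...         return 0.0
...     t = min((epoch - burn_in_epochs) / ramp_epochs, 1.0)
...     return t * max_unsup_weight
>>> linear_ramp(50)    # 5.0
>>> linear_ramp(200)   # 10.0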
Model Interpretation¶
The interpret module provides saliency tools for interpreting the decisions of trained scNym models.
class scnym.interpret.Salience(model, class_names, gene_names=None, layer_to_hook=None, verbose=False)¶
Bases: object
Performs backpropagation to compute gradients on a target class with regards to an input.
Saliency analysis computes a gradient on a target class score \(f_i(x)\) with regards to some input \(x\):
\[S_i = \frac{\partial f_i(x)}{\partial x}\]
get_saliency(x, target_class, guide_backprop=False)¶
Compute the saliency of a target class on an input vector x.
- Parameters
x (torch.FloatTensor) – [1, Genes] vector of gene expression.
target_class (str) – class in .class_names for which to compute gradients.
guide_backprop (bool) – perform "guided backpropagation" by clamping gradients to only positive values at each ReLU. See: https://arxiv.org/pdf/1412.6806.pdf
- Returns
salience – gradients on target_class with respect to x.
- Return type
torch.FloatTensor
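A usage sketch based on the documented signatures; the expected type of class_names is an assumption:
>>> import numpy as np
>>> import torch
>>> from scnym.model import CellTypeCLF
>>> from scnym.interpret import Salience
>>> model = CellTypeCLF(n_genes=784, n_cell_types=3)
>>> sal = Salience(
...     model=model,
...     class_names=np.array(['B cell', 'T cell', 'NK cell']),
... )
>>> s = sal.get_saliency(torch.randn(1, 784), target_class='B cell')
>>> # s holds gradients of the 'B cell' score with respect to the input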
Data Loaders¶
The dataprep module provides tools for loading and augmenting single cell data.
class scnym.dataprep.SingleCellDS(X, y, transform=None, num_classes=-1)¶
Bases: torch.utils.data.dataset.Dataset
Dataset class for loading single cell profiles.
- X (np.ndarray, sparse.csr_matrix) – [Cells, Genes] cell profiles.
- y_labels (np.ndarray, sparse.csr_matrix) – [Cells,] integer class labels.
- y (torch.FloatTensor) – [Cells, Classes] one-hot labels.
- transform (Callable) – performs data transformation operations on a sample dict.
- num_classes (int) – number of classes in the dataset. Default -1 infers the number of classes as len(unique(y)).
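A loading sketch; the structure of each yielded batch (documented only as a "sample dict") is an assumption:
>>> import numpy as np
>>> from torch.utils.data import DataLoader
>>> from scnym.dataprep import SingleCellDS
>>> X = np.random.poisson(1.0, (100, 784)).astype(np.float32)
>>> y = np.random.randint(0, 10, 100)
>>> ds = SingleCellDS(X, y)
>>> batch = next(iter(DataLoader(ds, batch_size=16)))   # a dict of tensors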
scnym.dataprep.balance_classes(y, class_min=256)¶
Perform class balancing by undersampling majority classes and oversampling minority classes, down to a minimum value.
- Parameters
y (np.ndarray) – class assignment indices.
class_min (int) – minimum number of examples to use for a class. below this value, minority classes will be oversampled with replacement.
- Returns
all_idx – indices for balanced classes. some indices may be repeated.
- Return type
np.ndarray
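A usage sketch; the exact per-class counts after balancing are an assumption based on the description above:
>>> import numpy as np
>>> from scnym.dataprep import balance_classes
>>> y = np.array([0] * 1000 + [1] * 50)    # imbalanced labels
>>> idx = balance_classes(y, class_min=256)
>>> # minority class 1 is oversampled with replacement toward class_min;
>>> # indexing y[idx] yields a more balanced label vector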
class scnym.dataprep.LibrarySizeNormalize(counts_per_cell_after=1000000, log1p=True)¶
Bases: object
Perform library size normalization.
class scnym.dataprep.ExpMinusOne¶
Bases: object
class scnym.dataprep.MultinomialSample(depth=(10000, 100000), depth_ratio=None)¶
Bases: object
Sample an mRNA abundance profile from a multinomial distribution parameterized by observations.
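A sketch of the underlying sampling operation, using the depth range suggested by the default arguments:
>>> import numpy as np
>>> profile = np.array([0.5, 0.3, 0.2])          # normalized expression profile
>>> depth = np.random.randint(10000, 100000)     # sequencing depth to simulate
>>> counts = np.random.multinomial(depth, profile)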
class scnym.dataprep.GeneMasking(p_drop=0.1, p_apply=0.5, sample_p_drop=False)¶
Bases: object
class scnym.dataprep.InputDropout(p_drop=0.1)¶
Bases: object
class scnym.dataprep.PoissonSample¶
Bases: object
Sample a gene expression profile based on gene-specific Poisson distributions.
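A sketch of the gene-specific Poisson resampling idea:
>>> import numpy as np
>>> x = np.array([5.0, 0.0, 2.0])     # observed counts for one cell
>>> x_res = np.random.poisson(x)      # each gene drawn from Poisson(rate=x_i)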
scnym.dataprep.mixup(a, b, gamma)¶
Perform a MixUp operation. This is effectively just a weighted average, where gamma = 0.5 yields the mean of a and b.
- Parameters
a (torch.FloatTensor) – [Batch, C] first sample matrix.
b (torch.FloatTensor) – [Batch, C] second sample matrix.
gamma (torch.FloatTensor) – [Batch,] MixUp coefficient.
- Returns
m – [Batch, C] mixed sample matrix.
- Return type
torch.FloatTensor
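A worked sketch of the documented operation, assuming the standard MixUp convention m = gamma * a + (1 - gamma) * b:
>>> import torch
>>> a, b = torch.ones(2, 4), torch.zeros(2, 4)
>>> gamma = torch.tensor([0.5, 0.9])
>>> m = gamma[:, None] * a + (1 - gamma[:, None]) * b
>>> m[0]   # tensor([0.5000, 0.5000, 0.5000, 0.5000]) – the mean of a and b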
class scnym.dataprep.SampleMixUp(alpha=0.2, keep_dominant_obs=False)¶
Bases: object
scnym.dataprep.identity(x)¶
Identity function.
- Return type
Any