Module Documentation

Import scNym as:

import scnym

Interactive API: api

scNym provides a simple Python API that serves as the primary endpoint for users. This API should be the first stop for users looking to apply scNym to new problems.

Classify cell identities using scNym

scnym_api() is the main API endpoint for users. It handles both training and prediction by calling scnym_train() and scnym_predict() internally, so most users will rarely need to call those two functions directly.

get_pretrained_weights() is a wrapper function that downloads pretrained weights from our cloud storage bucket. atlas2target() downloads preprocessed reference datasets and concatenates them onto a user supplied target dataset.

scnym.api.scnym_api(adata, task='train', groupby=None, out_path='./scnym_outputs', trained_model=None, config='new_identity_discovery', key_added='scNym', copy=False)

scNym: Semi-supervised adversarial neural networks for single cell classification [Kimmel2020].

scNym is a cell identity classifier that transfers annotations from one single cell experiment to another. The model is implemented as a neural network that employs MixMatch semi-supervision and a domain adversary to take advantage of unlabeled data during training. scNym offers superior performance to many baseline single cell identity classification methods.

Parameters
  • adata (AnnData) – Annotated data matrix used for training or prediction.

  • task (str) – Task to perform, either “train” or “predict”. If “train”, uses adata as labeled training data. If “predict”, uses trained_model to infer cell identities for observations in adata.

  • groupby (Optional[str]) – Column in adata.obs that contains cell identity annotations. Values of “Unlabeled” indicate that a given cell should be used only as unlabeled data during training.

  • out_path (str) – Path to a directory for saving scNym model weights and training logs.

  • trained_model (Optional[str]) – Used when task==“predict”. Path to the output directory of an scNym training run or a string specifying a pretrained model. Pretrained model strings are f“pretrained_{species}” where species is one of {“human”, “mouse”, “rat”}. Providing a pretrained model string will download pre-trained weights and predict directly on the target data, without additional training.

  • config (Union[dict, str]) –

    Configuration name or dictionary of configuration of parameters. Pre-defined configurations:

    “new_identity_discovery” - Default. Employs pseudolabel thresholding to allow for discovery of new cell identities in the target dataset using scNym confidence scores.

    “no_new_identity” - Assumes all cells in the target data belong to one of the classes in the training data. Recommended to improve performance when this assumption is valid.

  • key_added (str) – Key added to adata.obs with scNym predictions if task==“predict”.

  • copy (bool) – copy the AnnData object before predicting cell types.

Return type

Optional[AnnData]

Returns

  • Depending on copy, returns or updates adata with the following fields.

  • `X_scnym` (ndarray, (obsm, shape=(n_samples, n_hidden), dtype float)) – scNym embedding coordinates of data.

  • `scNym` ((adata.obs, dtype str)) – scNym cell identity predictions for each observation.

  • `scNym_train_results` (dict, (uns)) – results of scNym model training.

Examples

>>> import numpy as np
>>> import scanpy as sc
>>> from scnym.api import scnym_api, atlas2target

Loading Data and preparing labels

>>> adata = sc.datasets.kang17()
>>> target_bidx = adata.obs['stim']=='stim'
>>> adata.obs['cell'] = np.array(adata.obs['cell'])
>>> adata.obs.loc[target_bidx, 'cell'] = 'Unlabeled'

Train an scNym model

>>> scnym_api(
...   adata=adata,
...   task='train',
...   groupby='cell',
...   out_path='./scnym_outputs',
...   config='no_new_identity',
... )

Predict cell identities with the trained scNym model

>>> path_to_model = './scnym_outputs/'
>>> scnym_api(
...   adata=adata,
...   task='predict',
...   key_added='scNym',
...   trained_model=path_to_model,
...   config='no_new_identity',
... )

Predict cell identities with a pretrained scNym model

>>> scnym_api(
...   adata=adata,
...   task='predict',
...   key_added='scNym',
...   trained_model='pretrained_human',
...   config='no_new_identity',
... )

Perform semi-supervised training with an atlas

>>> joint_adata = atlas2target(
...   adata=adata,
...   species='human',
...   key_added='annotations',
... )
>>> scnym_api(
...   adata=joint_adata,
...   task='train',
...   groupby='annotations',
...   out_path='./scnym_outputs',
...   config='no_new_identity',
... )

scnym.api.scnym_train(adata, config)

Train an scNym model.

Parameters
  • adata (AnnData) – [Cells, Genes] experiment containing annotated cells to train on.

  • config (dict) – configuration options.

Return type

None

Returns

  • None. Saves model outputs to config[“out_path”] and adds model results to adata.uns[“scnym_train_results”].

Notes

This method should only be directly called by advanced users. Most users should use scnym_api.

See also

scnym_api()

scnym.api.scnym_predict(adata, config)

Predict cell identities using an scNym model.

Parameters
  • adata (AnnData) – [Cells, Genes] experiment containing cells to predict identities for.

  • config (dict) – configuration options.

Return type

None

Returns

None. Adds adata.obs[config[“key_added”]] and adata.obsm[“X_scnym”].

Notes

This method should only be directly called by advanced users. Most users should use scnym_api.

See also

scnym_api()

scnym.api.get_pretrained_weights(trained_model, out_path)

Given the name of a set of pretrained model weights, fetch the weights from Google Cloud Storage (GCS), save them under out_path, and return the species parsed from the model name.

Parameters
  • trained_model (str) – the name of a pretrained model to use, formatted as “pretrained_{species}”. species should be one of {“human”, “mouse”, “rat”}.

  • out_path (str) – path for saving model weights and outputs.

Return type

str

Returns

  • species (str) – species parsed from the trained model name.

  • Saves “{out_path}/00_best_model_weights.pkl” and “{out_path}/scnym_train_results.pkl”.

Notes

Requires an internet connection to download pre-trained weights.
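
The naming convention for pretrained models can be illustrated with a small, self-contained sketch. parse_pretrained_name is a hypothetical helper written for this example and is not part of the scNym API; it only shows how a string of the form “pretrained_{species}” maps to a species.

```python
# Hypothetical helper illustrating the "pretrained_{species}" convention.
# It is NOT part of the scNym API.
VALID_SPECIES = ("human", "mouse", "rat")


def parse_pretrained_name(trained_model: str) -> str:
    """Return the species encoded in a pretrained model name."""
    prefix = "pretrained_"
    if not trained_model.startswith(prefix):
        raise ValueError(f"expected a name like 'pretrained_human', got {trained_model!r}")
    species = trained_model[len(prefix):]
    if species not in VALID_SPECIES:
        raise ValueError(f"species must be one of {VALID_SPECIES}, got {species!r}")
    return species


print(parse_pretrained_name("pretrained_mouse"))  # mouse
```

Validating the name before any download is a design choice of this sketch; it fails fast on typos without needing a network connection.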

scnym.api.atlas2target(adata, species, key_added='annotations')

Download a preprocessed cell atlas dataset and append your new dataset as a target to allow for semi-supervised scNym training.

Parameters

adata (anndata.AnnData) – [Cells, Features] experiment to use as a target dataset.

Returns

joint_adata – [Cells, Features] experiment concatenated with a preprocessed cell atlas reference dataset. Annotations from the atlas are copied to .obs[key_added] and all cells in the target dataset adata are labeled with the special “Unlabeled” token.

Return type

anndata.AnnData

Examples

>>> import scanpy as sc
>>> import scnym
>>> adata = sc.datasets.pbmc3k()
>>> joint_adata = scnym.api.atlas2target(
...     adata=adata,
...     species='human',
...     key_added='annotations',
... )

Notes

Requires an internet connection to download reference datasets.

Advanced Interface

For users interested in exploring new research ideas with the scNym framework, we provide direct access to our underlying infrastructure. The modules below should be used by researchers looking to expand on the scNym approach.

Model Specification

class scnym.model.ResBlock(n_inputs, n_hidden)

Bases: torch.nn.modules.module.Module

Residual block.

References

Deep Residual Learning for Image Recognition Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun arXiv:1512.03385

forward(x)

Residual block forward pass.

Parameters

x (torch.FloatTensor) – [Batch, self.n_inputs]

Returns

o – [Batch, self.n_hidden]

Return type

torch.FloatTensor

class scnym.model.CellTypeCLF(n_genes, n_cell_types, n_hidden=256, n_layers=2, init_dropout=0.0, residual=False, batch_norm=True, track_running_stats=True)

Bases: torch.nn.modules.module.Module

Cell type classifier from expression data.

n_genes

number of input genes in the model.

Type

int

n_cell_types

number of output classes in the model.

Type

int

n_hidden

number of hidden units in the model.

Type

int

n_layers

number of hidden layers in the model.

Type

int

init_dropout

dropout proportion prior to the first layer.

Type

float

residual

use residual connections.

Type

bool

forward(x)

Perform a forward pass through the model

Parameters

x (torch.FloatTensor) – [Batch, self.n_genes]

Returns

pred – [Batch, self.n_cell_types]

Return type

torch.FloatTensor

class scnym.model.GradReverse

Bases: torch.autograd.function.Function

Layer that reverses and scales gradients before passing them up to earlier ops in the computation graph during backpropagation.

static forward(ctx, x, weight)

Perform a no-op forward pass that stores a weight for later gradient scaling during backprop.

Parameters
  • x (torch.FloatTensor) – [Batch, Features]

  • weight (float) – weight for scaling gradients during backpropagation. Stored in the “context” ctx variable.

Notes

We subclass Function and use only @staticmethod, as specified for new-style PyTorch autograd functions. https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function

We define a “context” ctx of the class that will hold any values passed during forward for use in the backward pass.

x.view_as(x) and *1 are necessary so that GradReverse is actually called. torch.autograd tries to optimize backprop and excludes no-ops, so we have to trick it :)

static backward(ctx, grad_output)

Return gradients

Returns

  • rev_grad (torch.FloatTensor) – reversed gradients scaled by weight passed in .forward()

  • None (None) – a dummy “gradient” required since we passed a weight float in .forward().
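
A minimal sketch of a gradient reversal function that follows the forward and backward descriptions above, assuming PyTorch is installed. GradReverseSketch is an illustrative name; the code mirrors the documented behavior but is not the scNym source.

```python
import torch
from torch.autograd import Function


class GradReverseSketch(Function):
    """Identity on the forward pass; reversed, scaled gradient on the backward pass."""

    @staticmethod
    def forward(ctx, x, weight):
        # Store the scaling weight on the context for use in backward().
        ctx.weight = weight
        # view_as(x) * 1 returns a new tensor rather than the input itself.
        return x.view_as(x) * 1.0

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign and scale. The second value is the dummy "gradient"
        # required because a float weight was passed to forward().
        return -ctx.weight * grad_output, None


x = torch.ones(2, 3, requires_grad=True)
y = GradReverseSketch.apply(x, 0.5)
y.sum().backward()
print(x.grad)  # every entry is -0.5
```

Without the reversal the gradient of y.sum() with respect to x would be all ones; here it is negated and scaled by the weight of 0.5.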

class scnym.model.DANN(model, n_domains=2, weight=1.0, n_layers=1)

Bases: torch.nn.modules.module.Module

Build a domain adaptation neural network

set_rev_grad_weight(weight)

Set the weight term used after reversing gradients

Return type

None

forward(x)

Perform a forward pass.

Parameters

x (torch.FloatTensor) – [Batch, Features] input.

Return type

Tuple[torch.FloatTensor, torch.FloatTensor]

Returns

  • domain_pred (torch.FloatTensor) – [Batch, n_domains] logits.

  • x_embed (torch.FloatTensor) – [Batch, n_hidden]

class scnym.model.CellTypeCLFConditional(n_genes, n_tissues, **kwargs)

Bases: scnym.model.CellTypeCLF

Conditional variation of the CellTypeCLF

n_genes

number of the input features corresponding to genes.

Type

int

n_tissues

length of the one-hot upper_group vector appended to inputs.

Type

int

forward(x)

Perform a forward pass through the model

Parameters

x (torch.FloatTensor) – [Batch, self.n_genes + self.n_tissues]

Returns

pred – [Batch, self.n_cell_types]

Return type

torch.FloatTensor

Model Trainer

The trainer module provides classes for training neural network models to classify single cells.

class scnym.trainer.Trainer(model, criterion, optimizer, dataloaders, out_path, batch_transformers={}, n_epochs=50, min_epochs=0, exp_name='', reg_criterion=None, use_gpu=False, verbose=False, save_freq=10, scheduler=None, tb_writer=None)

Bases: object

Trains a PyTorch model.

model

model with required .forward(…) method.

Type

nn.Module

criterion

loss criterion to optimize.

Type

Callable

optimizer

optimizer for the model parameters.

Type

torch.optim.Optimizer

dataloaders

keyed by [‘train’, ‘val’] with values corresponding to torch.utils.data.DataLoader for training and validation sets.

Type

dict

out_path

output path for best model.

Type

str

n_epochs

number of epochs for training.

Type

int

reg_criterion

criterion to penalize layer weights.

Type

Callable

use_gpu

use CUDA acceleration.

Type

bool

verbose

write all batch losses to stdout.

Type

bool

save_freq

Number of epochs between model checkpoints. Default = 10.

Type

int

scheduler

learning rate scheduler.

train_epoch()

Perform training across one full iteration through the data.

val_epoch()

Perform a pass through the validation data. Do not record gradients to speed things up.

train()

class scnym.trainer.SemiSupervisedTrainer(unsup_criterion, unsup_dataloader, unsup_weight, dan_criterion=None, dan_weight=None, **kwargs)

Bases: scnym.trainer.Trainer

train_epoch()

Perform training using both a supervised and semi-supervised loss.

Notes

  1. Sample labeled examples, compute the standard supervised loss.

  2. Sample unlabeled examples, compute unsupervised loss.

  3. Perform backward pass and update parameters.

Return type

None

scnym.trainer.get_class_weight(y)

Generate relative class weights based on the representation of classes in a label vector y

Parameters

y (np.ndarray) – [N,] vector of class labels.

Returns

class_weight – [Classes,] vector of loss weight coefficients. If classes are str, returns weights in lexicographically sorted order.

Return type

np.ndarray
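
The exact weighting formula is not stated above. One common choice, shown here purely as an assumption, is inverse-frequency weighting rescaled to a mean of 1. inverse_frequency_weights is a hypothetical name and may differ from what get_class_weight actually computes.

```python
import numpy as np


def inverse_frequency_weights(y: np.ndarray) -> np.ndarray:
    """Assumed scheme: weight each class by 1 / frequency, rescaled to mean 1."""
    # np.unique returns classes in sorted order (lexicographic for strings),
    # so the weights follow that same order.
    _, counts = np.unique(y, return_counts=True)
    weights = counts.sum() / counts
    return weights / weights.mean()


y = np.array(["B cell", "T cell", "T cell", "T cell"])
class_weight = inverse_frequency_weights(y)
print(class_weight)  # the rarer "B cell" class gets the larger weight
```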

scnym.trainer.cross_entropy(pred_, label, class_weight=None, sample_weight=None, reduction='mean')

Compute cross entropy loss for prediction outputs and potentially non-binary targets.

Parameters
  • pred_ (torch.FloatTensor) – [Batch, C] model outputs.

  • label (torch.FloatTensor) – [Batch, C] labels. may not necessarily be one-hot, but must satisfy simplex criterion.

  • class_weight (torch.FloatTensor) – [C,] relative weights for each of the output classes. useful for increasing attention to underrepresented classes.

  • reduction (str) – reduction method across the batch.

Returns

loss – mean cross-entropy loss across the batch indices.

Return type

torch.FloatTensor

Notes

Crossentropy is defined as:

\[H(P, Q) = -\sum_{k \in K} P(k) \log(Q(k))\]

where P, Q are discrete probability distributions defined with a common support K.
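
The definition above can be checked numerically with a short NumPy sketch. It assumes that both inputs are already probabilities whose rows sum to 1, which may differ from the raw model outputs that cross_entropy accepts; soft_cross_entropy and eps are names introduced for this example only.

```python
import numpy as np


def soft_cross_entropy(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float:
    """Mean over the batch of H(P, Q) = -sum_k P(k) log Q(k)."""
    # p: [Batch, C] target distributions (rows sum to 1, not necessarily one-hot).
    # q: [Batch, C] predicted probabilities (rows sum to 1).
    # eps avoids log(0) for classes with zero predicted probability.
    per_example = -(p * np.log(q + eps)).sum(axis=1)
    return float(per_example.mean())


p = np.array([[1.0, 0.0], [0.5, 0.5]])  # one-hot and soft targets
q = np.array([[0.5, 0.5], [0.5, 0.5]])  # uniform predictions
loss = soft_cross_entropy(p, q)
print(loss)  # close to log(2), since every prediction is uniform over 2 classes
```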

References

See for class weight computation: https://pytorch.org/docs/stable/nn.html#crossentropyloss

class scnym.trainer.InterpolationConsistencyLoss(unsup_criterion, sup_criterion, decay_coef=0.9997, mean_teacher=True, augment=None, teacher_eval=True, teacher_bn_running_stats=None, **kwargs)

Bases: object

scnym.trainer.sharpen_labels(q, T=0.5)

Reduce the entropy of a categorical label using a temperature adjustment

Parameters
  • q (torch.FloatTensor) – [N, C] pseudolabel.

  • T (float) – temperature parameter.

Returns

q_s – [C,] sharpened pseudolabel.

Return type

torch.FloatTensor

Notes

\[S(q, T)_i = q_i^{1/T} / \sum_{j=1}^{L} q_j^{1/T}\]
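
As a worked illustration of temperature sharpening, the sketch below applies the formula row by row with NumPy. sharpen is an illustrative stand-in rather than the scNym implementation, and it assumes each row of q is a probability vector.

```python
import numpy as np


def sharpen(q: np.ndarray, T: float = 0.5) -> np.ndarray:
    """Raise each probability to the power 1/T, then renormalize each row."""
    powered = q ** (1.0 / T)
    return powered / powered.sum(axis=1, keepdims=True)


q = np.array([[0.6, 0.4]])
q_s = sharpen(q, T=0.5)
print(q_s)  # approximately [[0.692, 0.308]]
```

With T = 0.5 the probabilities are squared before renormalizing, so 0.6 and 0.4 become 0.36 / 0.52 and 0.16 / 0.52. Temperatures below 1 push mass toward the most likely class, which lowers the entropy of the pseudolabel.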

class scnym.trainer.MixMatchLoss(n_augmentations=2, T=0.5, augment_pseudolabels=True, pseudolabel_min_confidence=0.0, **kwargs)

Bases: scnym.trainer.InterpolationConsistencyLoss

Compute the MixMatch Loss given a batch of labeled and unlabeled examples.

n_augmentations

number of augmented samples to average across when computing pseudolabels. default = 2 from MixMatch paper.

Type

int

T

temperature parameter.

Type

float

augment_pseudolabels

perform augmentations during pseudolabel generation.

Type

bool

pseudolabel_min_confidence

minimum confidence to compute a loss for a given pseudolabeled example. examples below this confidence threshold will be given 0 loss. see the FixMatch paper for discussion.

Type

float

teacher

teacher model for pseudolabeling.

Type

nn.Module

running_confidence_scores

[n_batches_to_store,] (torch.Tensor, torch.Tensor,) of unlabeled example (Confident_Bool, BestConfidenceScore) tuples.

Type

list

n_batches_to_store

determines how many batches to keep in running_confidence_scores.

Type

int
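
The effect of pseudolabel_min_confidence can be sketched with NumPy. This is an illustration under assumptions: confidence is taken to be the largest class probability in each pseudolabel, and the comparison is assumed to be inclusive (>=). confidence_mask is a hypothetical helper, not an scNym function.

```python
import numpy as np


def confidence_mask(pseudolabels: np.ndarray, min_confidence: float) -> np.ndarray:
    """Return 1.0 for rows whose top class probability meets the threshold, else 0.0."""
    best_score = pseudolabels.max(axis=1)
    return (best_score >= min_confidence).astype(float)


pseudolabels = np.array([[0.95, 0.05], [0.55, 0.45]])
per_example_loss = np.array([0.2, 0.7])
# Examples below the threshold contribute zero loss.
masked_loss = per_example_loss * confidence_mask(pseudolabels, min_confidence=0.9)
print(masked_loss)  # the second, low-confidence example is zeroed out
```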

class scnym.trainer.DANLoss(dan_criterion, model, use_conf_pseudolabels=False, scale_loss_pseudoconf=False)

Bases: object

Compute a domain adaptation network (DAN) loss.

class scnym.trainer.ICLWeight(ramp_epochs, burn_in_epochs=0, max_unsup_weight=10.0, sigmoid=False)

Bases: object

Model Interpretation

The interpret module provides saliency tools for interpreting model decisions.

Tools for interpreting trained scNym models

class scnym.interpret.Salience(model, class_names, gene_names=None, layer_to_hook=None, verbose=False)

Bases: object

Performs backpropagation to compute gradients on a target class with respect to an input.

Saliency analysis computes a gradient on a target class score \(f_i(x)\) with respect to some input \(x\).

\[S_i = \frac{\partial f_i(x)}{\partial x}\]

get_saliency(x, target_class, guide_backprop=False)

Compute the saliency of a target class on an input vector x.

Parameters
  • x (torch.FloatTensor) – [1, Genes] vector of gene expression.

  • target_class (str) – class in .class_names for which to compute gradients.

  • guide_backprop (bool) – perform “guided backpropagation” by clamping gradients to only positive values at each ReLU. see: https://arxiv.org/pdf/1412.6806.pdf

Returns

salience – gradients on target_class with respect to x.

Return type

torch.FloatTensor

rank_genes_by_saliency(**kwargs)

Rank genes by saliency for a target class and input.

Passes **kwargs to .get_saliency and uses the output to rank genes.

Returns

ranked_genes – gene names with high saliency, ranked highest to lowest.

Return type

np.ndarray
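
The idea of ranking genes by saliency can be sketched without a trained network. The example below swaps in a toy scoring function and approximates the gradient with central finite differences instead of backpropagation, so it illustrates the concept rather than the Salience class itself. The gene names, weights, and helper functions are all made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
gene_names = np.array(["GeneA", "GeneB", "GeneC", "GeneD"])  # made-up names
W = rng.normal(size=(3, 4))  # toy model weights: 3 classes, 4 genes


def class_score(x: np.ndarray, target: int) -> float:
    """Toy class score f_i(x); a real model would be a neural network."""
    return float(np.tanh(W[target] @ x))


def saliency(x: np.ndarray, target: int, h: float = 1e-5) -> np.ndarray:
    """Approximate S_i = d f_i(x) / d x with central finite differences."""
    grad = np.zeros_like(x)
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = h
        grad[j] = (class_score(x + step, target) - class_score(x - step, target)) / (2 * h)
    return grad


x = rng.random(4)                          # one toy expression profile
s = saliency(x, target=1)
ranked_genes = gene_names[np.argsort(-s)]  # highest saliency first
print(ranked_genes)
```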

Data Loaders

The dataprep module provides tools for loading and augmenting single cell data.

class scnym.dataprep.SingleCellDS(X, y, transform=None, num_classes=-1)

Bases: torch.utils.data.dataset.Dataset

Dataset class for loading single cell profiles.

X

[Cells, Genes] cell profiles.

Type

np.ndarray, sparse.csr_matrix

y_labels

[Cells,] integer class labels.

Type

np.ndarray, sparse.csr_matrix

y

[Cells, Classes] one hot labels.

Type

torch.FloatTensor

transform

performs data transformation operations on a sample dict.

Type

Callable

num_classes

number of classes in the dataset. default -1 infers the number of classes as len(unique(y)).

Type

int
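
The relationship between the integer labels and the one-hot matrix described above can be sketched in NumPy. one_hot is an illustrative helper, not the SingleCellDS code, and it assumes the labels are integers in the range 0 to Classes - 1.

```python
import numpy as np


def one_hot(y_labels: np.ndarray, num_classes: int = -1) -> np.ndarray:
    """Convert [Cells,] integer labels to a [Cells, Classes] one-hot matrix."""
    if num_classes == -1:
        # Infer the class count from the labels, as described for the default.
        num_classes = len(np.unique(y_labels))
    encoded = np.zeros((y_labels.shape[0], num_classes), dtype=np.float32)
    encoded[np.arange(y_labels.shape[0]), y_labels] = 1.0
    return encoded


y_labels = np.array([0, 2, 1])
y = one_hot(y_labels)
print(y.shape)  # (3, 3)
```

Passing num_classes explicitly matters when a subset of the data, such as a validation split, does not contain every class.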

scnym.dataprep.balance_classes(y, class_min=256)

Perform class balancing by undersampling majority classes and oversampling minority classes, down to a minimum value.

Parameters
  • y (np.ndarray) – class assignment indices.

  • class_min (int) – minimum number of examples to use for a class. below this value, minority classes will be oversampled with replacement.

Returns

all_idx – indices for balanced classes. some indices may be repeated.

Return type

np.ndarray
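
One possible reading of this balancing scheme is sketched below. It assumes every class is resampled to the size of the smallest class, but never to fewer than class_min examples. balance_indices and its seed argument are hypothetical, and the actual function may differ in detail.

```python
import numpy as np


def balance_indices(y: np.ndarray, class_min: int = 256, seed: int = 0) -> np.ndarray:
    """Sample the same number of indices from every class."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    # Assumed target size: the smallest class, but never fewer than class_min.
    n_per_class = max(int(counts.min()), class_min)
    picked = []
    for c in classes:
        idx = np.flatnonzero(y == c)
        # Oversample with replacement only when the class is too small.
        replace = idx.size < n_per_class
        picked.append(rng.choice(idx, size=n_per_class, replace=replace))
    return np.concatenate(picked)


y = np.array([0] * 10 + [1] * 3)
all_idx = balance_indices(y, class_min=5)
print(np.bincount(y[all_idx]))  # both classes have 5 examples
```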

class scnym.dataprep.LibrarySizeNormalize(counts_per_cell_after=1000000, log1p=True)

Bases: object

Perform library size normalization.
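
A minimal sketch of library size normalization for a single cell, assuming the two constructor arguments mean "rescale to this total count" and "apply log(1 + x) afterwards". library_size_normalize is an illustrative function, not the class above.

```python
import numpy as np


def library_size_normalize(counts, counts_per_cell_after=1_000_000, log1p=True):
    """Scale one cell's count vector to a fixed total, then optionally log1p it."""
    counts = np.asarray(counts, dtype=np.float64)
    total = counts.sum()
    # Guard against empty profiles to avoid dividing by zero.
    scaled = counts / total * counts_per_cell_after if total > 0 else counts
    return np.log1p(scaled) if log1p else scaled


cell = [1, 3, 0, 6]
normalized = library_size_normalize(cell, log1p=False)
print(normalized.sum())  # total count equals counts_per_cell_after
```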

class scnym.dataprep.ExpMinusOne

Bases: object

class scnym.dataprep.MultinomialSample(depth=(10000, 100000), depth_ratio=None)

Bases: object

Sample an mRNA abundance profile from a multinomial distribution parameterized by observations.
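
The sampling idea can be sketched with NumPy. The code assumes the observed profile is normalized into multinomial probabilities and that the sequencing depth is drawn uniformly from a (low, high) range; both are assumptions made for this example. multinomial_sample, depth_range, and seed are hypothetical names.

```python
import numpy as np


def multinomial_sample(profile, depth_range=(10_000, 100_000), seed=0):
    """Draw counts from Multinomial(depth, p), with p proportional to the profile."""
    rng = np.random.default_rng(seed)
    profile = np.asarray(profile, dtype=np.float64)
    p = profile / profile.sum()
    # Assumption: depth is drawn uniformly from [low, high).
    depth = int(rng.integers(depth_range[0], depth_range[1]))
    return rng.multinomial(depth, p)


observed = [5, 0, 3, 2]
sample = multinomial_sample(observed)
print(sample.sum())  # the sampled depth, between 10000 and 99999
```

Genes with zero observed counts have zero probability, so they stay at zero in every sample.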

class scnym.dataprep.GeneMasking(p_drop=0.1, p_apply=0.5, sample_p_drop=False)

Bases: object

class scnym.dataprep.InputDropout(p_drop=0.1)

Bases: object

class scnym.dataprep.PoissonSample

Bases: object

Sample a gene expression profile based on gene-specific Poisson distributions
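
A short sketch of gene-wise Poisson resampling, assuming each observed value is used directly as the Poisson rate for that gene. poisson_sample and seed are illustrative names, not part of the class above.

```python
import numpy as np


def poisson_sample(profile, seed=0):
    """Resample each gene from a Poisson distribution with rate equal to its observed value."""
    rng = np.random.default_rng(seed)
    return rng.poisson(lam=np.asarray(profile, dtype=np.float64))


observed = np.array([0.0, 4.0, 20.0, 150.0])
resampled = poisson_sample(observed)
print(resampled.shape)  # (4,)
```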

scnym.dataprep.mixup(a, b, gamma)

Perform a MixUp operation. This is effectively just a weighted average, where gamma = 0.5 yields the mean of a and b.

Parameters
  • a (torch.FloatTensor) – [Batch, C] first sample matrix.

  • b (torch.FloatTensor) – [Batch, C] second sample matrix.

  • gamma (torch.FloatTensor) – [Batch,] MixUp coefficient.

Returns

m – [Batch, C] mixed sample matrix.

Return type

torch.FloatTensor
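
The weighted average can be written out explicitly. The sketch below uses NumPy rather than torch tensors and assumes gamma is the weight on a, with 1 - gamma on b; the source only states that gamma = 0.5 yields the mean, so that direction is an assumption. mixup_sketch is an illustrative name.

```python
import numpy as np


def mixup_sketch(a, b, gamma):
    """Row-wise weighted average: gamma * a + (1 - gamma) * b."""
    # Reshape gamma from [Batch,] to [Batch, 1] so it broadcasts over the C columns.
    gamma = np.asarray(gamma, dtype=np.float64).reshape(-1, 1)
    return gamma * a + (1.0 - gamma) * b


a = np.zeros((2, 3))
b = np.ones((2, 3))
m = mixup_sketch(a, b, gamma=[0.5, 1.0])
print(m)  # first row is the mean of a and b, second row equals a
```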

class scnym.dataprep.SampleMixUp(alpha=0.2, keep_dominant_obs=False)

Bases: object

scnym.dataprep.identity(x)

Identity function

Return type

Any