scNym/interpret Module

Tools for interpreting trained scNym models

class scnym.interpret.Salience(model, class_names, gene_names=None, layer_to_hook=None, verbose=False)

Bases: object

Performs backpropagation to compute gradients of a target class score with respect to an input.

Saliency analysis computes the gradient of a target class score \(f_i(x)\) with respect to an input \(x\):

\[S_i = \frac{\partial f_i(x)}{\partial x}\]
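As a worked illustration, assume a hypothetical two-layer ReLU classifier (not necessarily scNym's architecture) with hidden pre-activations \(z = W^{(1)}x + b^{(1)}\) and class scores \(f(x) = W^{(2)}\,\mathrm{ReLU}(z) + b^{(2)}\). Applying the chain rule and writing the gradient as a row vector to match the [1, Genes] layout of \(x\):

\[S_i = \frac{\partial f_i(x)}{\partial x} = \left(W^{(2)}_{i,:} \odot \mathbb{1}[z > 0]^\top\right) W^{(1)}\]

Here \(\odot\) is the elementwise product and \(\mathbb{1}[z > 0]\) is the derivative of the ReLU, so hidden units that are inactive for this input contribute nothing to the saliency of class \(i\).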

get_saliency(x, target_class, guide_backprop=False)

Compute the saliency of a target class on an input vector x.

Parameters
  • x (torch.FloatTensor) – [1, Genes] vector of gene expression.

  • target_class (str) – class in .class_names for which to compute gradients.

  • guide_backprop (bool) – perform “guided backpropagation” by clamping gradients to only positive values at each ReLU. See: https://arxiv.org/pdf/1412.6806.pdf

Returns

salience – gradients of the target_class score with respect to x.

Return type

torch.FloatTensor
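The method itself works on a PyTorch model; the framework-free NumPy sketch below only illustrates the quantity it is documented to return, and does not reproduce scNym's implementation. It assumes a hypothetical two-layer ReLU classifier with random weights and hypothetical class labels, and the helper names (`get_saliency_sketch`, `finite_difference`) are illustrative only. The vanilla gradient is checked against central finite differences; with `guide_backprop=True`, negative gradients are additionally zeroed at the ReLU, as described for the guided backpropagation option.

```python
import numpy as np

# Hypothetical stand-in for a trained classifier, NOT scNym's architecture:
# f(x) = W2 @ relu(W1 @ x + b1) + b2 with random weights.
rng = np.random.default_rng(0)
n_genes, n_hidden, n_classes = 6, 10, 3
W1 = rng.normal(size=(n_hidden, n_genes))
b1 = rng.normal(size=n_hidden)
W2 = rng.normal(size=(n_classes, n_hidden))
b2 = rng.normal(size=n_classes)
class_names = ["T cell", "B cell", "Macrophage"]  # hypothetical labels


def class_scores(x: np.ndarray) -> np.ndarray:
    """Class score vector f(x) for a 1-D expression vector x."""
    return W2 @ np.maximum(W1 @ x + b1, 0.0) + b2


def get_saliency_sketch(x: np.ndarray, target_class: str,
                        guide_backprop: bool = False) -> np.ndarray:
    """Gradient of the target class score with respect to x, shape [n_genes]."""
    i = class_names.index(target_class)
    z = W1 @ x + b1                    # hidden pre-activations
    grad = W2[i] * (z > 0)             # through the ReLU: zero where z <= 0
    if guide_backprop:
        grad = np.maximum(grad, 0.0)   # guided: also zero negative gradients
    return grad @ W1                   # chain rule back to the input genes


def finite_difference(x: np.ndarray, target_class: str,
                      eps: float = 1e-6) -> np.ndarray:
    """Central-difference estimate of d f_i / d x, used as a check."""
    i = class_names.index(target_class)
    grad = np.zeros_like(x)
    for g in range(x.size):
        step = np.zeros_like(x)
        step[g] = eps
        grad[g] = (class_scores(x + step)[i] - class_scores(x - step)[i]) / (2 * eps)
    return grad


x = rng.normal(size=n_genes)
vanilla = get_saliency_sketch(x, "B cell")
guided = get_saliency_sketch(x, "B cell", guide_backprop=True)
print("vanilla matches finite differences:",
      np.allclose(vanilla, finite_difference(x, "B cell"), atol=1e-5))
print("guided saliency:", guided)
```

Note that guided backpropagation clamps only at the ReLU, so entries of the final input gradient can still be negative when the first-layer weights are negative.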

rank_genes_by_saliency(**kwargs)

Rank genes by saliency for a target class and input.

Passes **kwargs to .get_saliency and uses the output to rank genes.

Returns

ranked_genes – gene names with high saliency, ranked highest to lowest.

Return type

np.ndarray
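One plausible way to turn a saliency vector into a ranked gene list is a descending sort, sketched below with NumPy. The function name `rank_genes_by_saliency_sketch` and the toy gene names are hypothetical, and ranking by the signed gradient (rather than its absolute value) is an assumption about what "high saliency" means here.

```python
import numpy as np


def rank_genes_by_saliency_sketch(salience, gene_names):
    """Return gene names sorted from highest to lowest saliency.

    salience: [1, Genes] or [Genes] array of gradients for one target class.
    gene_names: sequence of Genes names aligned with the columns of salience.
    """
    scores = np.asarray(salience, dtype=float).reshape(-1)
    names = np.asarray(gene_names)
    if scores.size != names.size:
        raise ValueError("salience and gene_names must have the same length")
    # Sort the negated scores ascending to get a descending order;
    # the stable sort keeps tied genes in their input order.
    order = np.argsort(-scores, kind="stable")
    return names[order]


salience = np.array([[0.1, -2.0, 0.7, 0.3]])  # toy [1, Genes] gradients
gene_names = ["GeneA", "GeneB", "GeneC", "GeneD"]
print(rank_genes_by_saliency_sketch(salience, gene_names))
# ['GeneC' 'GeneD' 'GeneA' 'GeneB']
```

Ranking by `np.abs(scores)` instead would surface genes with large gradients in either direction, which may be preferable if decreases in expression are also of interest.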