In this tutorial I will cover how to use the new RP model API to predict gene expression and find trans-regulated genes. The CisModeler class is the entry point for all methods and implements an sklearn-style API similar to the topic models, namely fit, score, predict, and get_logp.
#kladi imports
from kladi.rp_model import CisModeler
from kladi.matrix_models.estimator import ExpressionTrainer, AccessibilityTrainer
from kladi.core.plot_utils import RAW_UMAP, IMPUTED_UMAP #scanpy umap plot formatting shortcuts
#general imports
import anndata
import scanpy as sc
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import seaborn as sns
To instantiate a CisModeler object, we need to provide trained AccessibilityModel and ExpressionModel objects.
rna_model = ExpressionTrainer.load('data/mouse_prostate/dispersion_test.pth')
atac_model = AccessibilityTrainer.load('/Users/alynch/Dropbox (Partners HealthCare)/Data/mouse_prostate/best_atac_model.pth')
We also need to load training data. To train cis-models, we must provide expression counts and accessibility counts with shared barcodes.
joint_rna_data = anndata.read_h5ad('data/mouse_prostate/checkpoint/joint_rna_data.h5ad')
joint_atac_data = anndata.read_h5ad('data/mouse_prostate/checkpoint/joint_atac_data.h5ad')
assert(all(joint_rna_data.obs_names == joint_atac_data.obs_names)) #make sure barcodes are aligned
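If the assertion above fails, the two modalities need to be restricted to their shared barcodes and put in the same order. Here is a minimal sketch of that alignment using plain Python lists. The function name shared_barcode_indices and the toy barcodes are my own illustration, not part of kladi, and the sketch assumes barcodes are unique within each modality.

```python
def shared_barcode_indices(rna_barcodes, atac_barcodes):
    """Return (rna_idx, atac_idx) that list the shared barcodes in the same order."""
    atac_position = {barcode: i for i, barcode in enumerate(atac_barcodes)}
    rna_idx, atac_idx = [], []
    for i, barcode in enumerate(rna_barcodes):
        j = atac_position.get(barcode)
        if j is not None:  # keep only barcodes present in both modalities
            rna_idx.append(i)
            atac_idx.append(j)
    return rna_idx, atac_idx


# Toy barcodes, not taken from the mouse prostate data.
rna_barcodes = ["AAAC-1", "AAAG-1", "AATT-1"]
atac_barcodes = ["AATT-1", "AAAC-1", "CCCC-1"]

rna_idx, atac_idx = shared_barcode_indices(rna_barcodes, atac_barcodes)
aligned_rna = [rna_barcodes[i] for i in rna_idx]
aligned_atac = [atac_barcodes[j] for j in atac_idx]
```

The two index lists could then be used to subset each AnnData object along its cells before re-running the assertion.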
Lastly, when we instantiate a CisModeler, by default it will attempt to learn a model for every gene modeled by the expression topic model. This can be time-consuming, and not every gene may be appropriate or needed for downstream analysis.
I recommend only training RP models for highly variable genes, at least to start. For this tutorial, I will demonstrate with a sample set of 40 genes.
with open('data/mouse_prostate/test_genes.txt', 'r') as f:
    test_genes = [x.strip() for x in f.readlines()]
cis_model = CisModeler('mm10', #species
expression_model=rna_model,
accessibility_model=atac_model,
genes = test_genes) #to train for all variable genes: rna_model.genes[rna_model.highly_variable]
No adjacent peaks to gene BRINP3
Gene 1700021F05RIK not in RefSeq database for this species
When fitting/predicting/scoring RP models, the expression and accessibility states of the cells must be processed into features. Each modeling method (fit, score, predict, and get_logp) takes the raw data as arguments: expression and accessibility_matrix, plus an optional idx of cell indices.
Processing these features is time-consuming, often taking longer than the actual modeling step itself. To mitigate this, when you supply these features to the cis_model object, it will process and save them.
Then, for each method call, if features are not provided, the saved features will be used. To get this step out of the way, you can use CisModeler.load_features:
cis_model.load_features(expression = joint_rna_data[:, rna_model.genes].X, accessibility_matrix = joint_atac_data.X)
Predicting latent vars: 100%|██████████| 6/6 [00:18<00:00, 3.11s/it]
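To make the caching behavior concrete, here is a toy sketch of the pattern: process once, reuse on later calls, and optionally subset by cell index. This is not kladi's implementation. FeatureCache, get_features and the summing "processing" function are hypothetical stand-ins I made up for illustration.

```python
class FeatureCache:
    """Toy stand-in for an object that processes features once and reuses them."""

    def __init__(self, process):
        self._process = process  # the expensive processing function
        self._features = None    # filled in by load_features
        self.n_processed = 0     # counts how often processing actually ran

    def load_features(self, raw):
        self._features = self._process(raw)
        self.n_processed += 1
        return self

    def get_features(self, raw=None, idx=None):
        if raw is not None:  # new data supplied: process and save it
            self.load_features(raw)
        if self._features is None:
            raise ValueError("No features registered; call load_features first.")
        if idx is None:
            return self._features
        return [self._features[i] for i in idx]  # subset cached features by cell index


# Each row is one "cell"; the processing step here just sums the row.
cache = FeatureCache(lambda rows: [sum(row) for row in rows])
cache.load_features([[1, 2], [3, 4], [5, 6]])

all_features = cache.get_features()               # reuses the saved features
train_features = cache.get_features(idx=[0, 2])   # subset without reprocessing
```

Both calls after load_features reuse the saved result, so n_processed stays at 1.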
Now, if you want to split the cells into train/test sets, you can simply pass the indices of each set to the modeling methods, and they will subset the registered features.
In my case, I only have 1000 cells, so splitting would significantly reduce my training set and the remaining test set may not be a good representation of the data, so I skip this step. For reference, a split would look like this:
train_idx, test_idx = train_test_split(np.arange(len(joint_rna_data)), train_size=0.8)
cis_model.fit(idx=train_idx)
Next, fit the models. Each gene's model is fit using L-BFGS, a quasi-Newton optimizer that approximates second-order information. Its learning rate is robust and does not need to be tuned the way it might for SGD.
cis_model.fit()
Training models: 100%|██████████| 39/39 [00:50<00:00, 1.30s/it]
<kladi.rp_model.CisModeler at 0x7ff1b81fc6d0>
Saving and loading are simple: just provide a prefix and each model will be saved as <prefix>_<gene>.pth.
To reload models, instantiate a CisModeler object as above, then provide the prefix:
cis_model.save('data/mouse_prostate/cis_models/') #save each gene's model under the prefix
cis_model.load('data/mouse_prostate/cis_models/') #reload the saved models from the same prefix
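As a quick illustration of the naming pattern, the sketch below builds the file names I would expect for a prefix, assuming the <prefix>_<gene>.pth pattern is applied literally by string concatenation. The helpers model_path and expected_paths are hypothetical and only mirror the pattern stated in the text; I have not checked how kladi joins the pieces internally.

```python
def model_path(prefix, gene):
    """Build a per-gene file name following the <prefix>_<gene>.pth pattern."""
    return f"{prefix}_{gene}.pth"


def expected_paths(prefix, genes):
    """Map each gene to the file name expected under the given prefix."""
    return {gene: model_path(prefix, gene) for gene in genes}


paths = expected_paths("data/mouse_prostate/cis_models/", ["LEF1", "TACO1"])
```

Under that assumption, a prefix that ends in a slash gives file names that begin with an underscore inside the directory, which is worth keeping in mind when looking for the saved files.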
Let's evaluate the models:
cis_model.score()
Scoring models: 100%|██████████| 39/39 [00:00<00:00, 147.12it/s]
0.28241319079443633
To make predictions on all cells:
cis_predictions = anndata.AnnData(
X = cis_model.predict(), #predict expression using cis models
var = pd.DataFrame(index = cis_model.genes), #provide genes as columns
obs = joint_atac_data.obs, #copy obs from atac data
obsm = joint_atac_data.obsm) #copy obsm from atac data
Predicting expression: 100%|██████████| 39/39 [00:00<00:00, 305.19it/s]
sc.pl.umap(cis_predictions, color = ['LEF1','TACO1'], frameon = False, palette = 'inferno')
sc.pl.umap(joint_rna_data, color = ['LEF1','TACO1'], **RAW_UMAP)
To analyze a particular gene's model, you may use the CisModeler.get_model method.
The .guide() method returns the MAP estimates of the model's parameters. Most notably, the a parameter scales the effects of upstream/promoter/downstream accessibility, and the logdistance parameter shows the upstream and downstream decay distances in Kb, respectively.
cis_model.get_model('TACO1').guide()
{'cis_TACO1/a': tensor([1.0196e-02, 1.1395e+00, 2.5541e-06], grad_fn=<ExpandBackward>), 'cis_TACO1/logdistance': tensor([121.1636, 2.2378], grad_fn=<ExpandBackward>), 'cis_TACO1/theta': tensor(0.1422, grad_fn=<ExpandBackward>), 'cis_TACO1/gamma': tensor(2.0226, grad_fn=<ExpandBackward>), 'cis_TACO1/bias': tensor(0.2882, grad_fn=<ExpandBackward>)}
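To give some intuition for how these parameters interact, here is a toy regulatory-potential calculation. The functional form is an assumption on my part: each peak's accessibility is down-weighted so that its influence halves every decay distance, and the three a values scale the upstream, promoter and downstream sums. The actual formula used by kladi may differ, and rp_weight, toy_rp_score and the example peaks are hypothetical. The parameter values are rounded from the TACO1 output above.

```python
def rp_weight(distance_kb, decay_kb):
    """Weight of a peak at distance_kb from the gene; halves every decay_kb (assumed form)."""
    return 0.5 ** (abs(distance_kb) / decay_kb)


def toy_rp_score(peaks, a, decay_kb):
    """peaks: list of (region, distance_kb, accessibility) tuples.

    a: (upstream, promoter, downstream) scales; decay_kb: (upstream, downstream) decay distances.
    """
    a_up, a_promoter, a_down = a
    decay_up, decay_down = decay_kb
    upstream = sum(acc * rp_weight(d, decay_up) for region, d, acc in peaks if region == "upstream")
    promoter = sum(acc for region, d, acc in peaks if region == "promoter")
    downstream = sum(acc * rp_weight(d, decay_down) for region, d, acc in peaks if region == "downstream")
    return a_up * upstream + a_promoter * promoter + a_down * downstream


# Made-up peaks; parameters rounded from the TACO1 guide output.
peaks = [("upstream", 121.0, 2.0), ("promoter", 0.0, 3.0), ("downstream", 2.0, 1.0)]
score = toy_rp_score(peaks, a=(1.02e-2, 1.14, 2.55e-6), decay_kb=(121.16, 2.24))
promoter_share = (1.14 * 3.0) / score
```

With these values almost all of the score comes from the promoter term, which matches the impression given by the large middle entry of a.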
Next, we want to identify genes that show interesting behavior with respect to their proximal chromatin, and to categorize genes based on that relationship. To facilitate this, we must train a second group of models, this time providing those models with access to the ATAC latent features, giving them a "view" of the entire cell state from which to make predictions. The performance of these "trans" models relative to the "cis" models serves as the cis/trans test.
Trans models are instantiated in exactly the same way as cis models, except for the differences noted below:
trans_model = CisModeler(
'mm10', expression_model=rna_model, accessibility_model=atac_model, genes =test_genes, #same as cis models
use_trans_features=True, #set to true
cis_models = cis_model, #provide pre-trained cis models as starting initialization
)
No adjacent peaks to gene BRINP3
Gene 1700021F05RIK not in RefSeq database for this species
trans_model.fit(expression = joint_rna_data[:, rna_model.genes].X, accessibility_matrix = joint_atac_data.X) #example of providing features with method
Predicting latent vars: 100%|██████████| 6/6 [00:18<00:00, 3.15s/it]
Training models: 100%|██████████| 39/39 [00:42<00:00, 1.10s/it]
<kladi.rp_model.CisModeler at 0x7ff1babb73d0>
trans_model.save('data/mouse_prostate/trans_models/')
The trans models have the exact same interface and perform all of the same methods as cis models. To identify trans-regulated genes, use CisModeler.cis_trans_test. It is a good idea to explicitly provide features to this method so that you can be sure the cis and trans models are being scored on the same cells.
(Note: you must call this method from the trans model.)
results = trans_model.cis_trans_test(cis_model, #required, provide the cis model to compare predictions
expression = joint_rna_data[:, rna_model.genes].X, #features
accessibility_matrix = joint_atac_data.X, #features
#idx = test_idx --> provide test_idx if applicable
)
Predicting latent vars: 100%|██████████| 6/6 [00:15<00:00, 2.65s/it]
Predicting latent vars: 100%|██████████| 6/6 [00:13<00:00, 2.31s/it]
results = pd.DataFrame(results)
The distribution of test_statistic shows that most genes have test_statistic < 20, which marks them as cis-regulated. For some genes, however, the trans model does significantly better at describing the data:
print(*results[results.significant].gene.values, sep = ', ')
TACO1, PLAT, CAPNS1, VCAM1, RORB, LY6A, FAM107B, COL17A1, TNFAIP8, SEL1L, IL1R2, PREX2, ABTB2
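If you want to rank genes by the strength of the evidence rather than just list the significant ones, a small sketch like the following works on the raw records. The field names gene and test_statistic come from the results above. The cutoff of 20 is taken from the text, but treating it as the exact rule behind the significant flag is my assumption, and the gene names and numbers below are made up for illustration.

```python
def split_by_threshold(records, threshold=20.0):
    """Split records into cis-like and trans-like genes, trans sorted by descending statistic."""
    trans = sorted(
        (r for r in records if r["test_statistic"] >= threshold),
        key=lambda r: r["test_statistic"],
        reverse=True,
    )
    cis = [r for r in records if r["test_statistic"] < threshold]
    return cis, trans


# Made-up records with placeholder gene names, not actual results.
toy_records = [
    {"gene": "GENE_A", "test_statistic": 3.2},
    {"gene": "GENE_B", "test_statistic": 41.7},
    {"gene": "GENE_C", "test_statistic": 12.9},
    {"gene": "GENE_D", "test_statistic": 88.0},
]

cis_like, trans_like = split_by_threshold(toy_records)
trans_ranked = [r["gene"] for r in trans_like]
```

Sorting this way puts the genes where the trans model helps most at the top of the list.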
ax = sns.displot(data = results, x = 'test_statistic', bins = 10)
<seaborn.axisgrid.FacetGrid at 0x7ff1bb021890>
It can be useful to compare the probability of observing the data under both models to see where the trans model performed better. We can compare the log-probability of the data given the cis and trans models like so:
delta_logp = anndata.AnnData(X = trans_model.get_logp() - cis_model.get_logp(), # the difference in prob of the trans prediction vs cis prediction. When X > 0, the trans model is doing better.
var = pd.DataFrame(index = cis_model.genes), #provide genes as columns
obs = joint_atac_data.obs, #copy obs from atac data
obsm = joint_atac_data.obsm)
delta_logp.X = np.where(delta_logp.X < 0, np.nan, delta_logp.X) #mask < 0 to make plot more readable
Scoring models: 100%|██████████| 39/39 [00:00<00:00, 144.72it/s]
Scoring models: 100%|██████████| 39/39 [00:00<00:00, 194.02it/s]
sc.pl.umap(delta_logp, color = results.head(4).gene.values, color_map = 'Reds', na_color='white', add_outline=True, frameon=False, outline_color=('white','lightgrey'))
sc.pl.umap(joint_rna_data, color = results.head(4).gene.values, **RAW_UMAP)
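Besides plotting, it can help to reduce the difference in log-probability to one number per gene. The sketch below computes, for each gene, the fraction of cells where the trans model assigns a higher log-probability than the cis model. It assumes the matrices are laid out as cells by genes, which is how they are used in the AnnData objects above. The function name fraction_trans_better and the toy values are my own.

```python
def fraction_trans_better(trans_logp, cis_logp):
    """Per gene (column), the fraction of cells (rows) where trans log-prob exceeds cis log-prob."""
    n_cells = len(trans_logp)
    n_genes = len(trans_logp[0])
    return [
        sum(trans_logp[c][g] > cis_logp[c][g] for c in range(n_cells)) / n_cells
        for g in range(n_genes)
    ]


# Toy log-probabilities: 4 cells (rows) by 2 genes (columns), values made up.
toy_trans = [[-1.0, -2.0], [-1.5, -2.0], [-0.5, -3.0], [-1.0, -1.0]]
toy_cis = [[-2.0, -2.0], [-2.5, -1.0], [-0.4, -3.5], [-3.0, -1.0]]

fractions = fraction_trans_better(toy_trans, toy_cis)
```

A gene with a high fraction is better explained by the trans model across most cells, while a low fraction with a large test statistic points to a small group of cells driving the difference.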
Coming soon! I'll work on the install requirements for the Dynamic tracks package.
Here's a rundown on the steps for finding motif hits. I removed the downloading step for motifs and saved them with the package, and I also ran the motifs through a parser so they should be formatted correctly.
import logging
logging.basicConfig(level=logging.INFO)
logging.info('test') # turn on logging to see progress
INFO:root:test
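If the INFO output from other libraries gets noisy, one option is to raise the level only for kladi's loggers. The sketch below relies on the standard library's hierarchical logger names. The logger name kladi.motif_scanning.moods_scan is taken from the output further down, and I am assuming all of the package's loggers sit under the kladi parent.

```python
import logging

# Root logger stays at WARNING so other libraries remain quiet.
# Note: basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.WARNING, format="%(levelname)s:%(name)s:%(message)s")

# Loggers are hierarchical, so setting the 'kladi' parent also covers its children.
logging.getLogger("kladi").setLevel(logging.INFO)

scan_logger = logging.getLogger("kladi.motif_scanning.moods_scan")
scan_logger.info("progress messages from this logger are now emitted")
```

This would be used in place of the blanket basicConfig(level=logging.INFO) call, since a second basicConfig call in the same session has no effect.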
atac_model.get_motif_hits_in_peaks('/Users/alynch/genomes/mm10/mm10.fa')
INFO:kladi.motif_scanning.moods_scan:Getting peak sequences ...
148304it [00:13, 11091.07it/s]
INFO:kladi.motif_scanning.moods_scan:Scanning peaks for motif hits with p >= 5e-05 ...
INFO:kladi.motif_scanning.moods_scan:Building motif background models ...
INFO:kladi.motif_scanning.moods_scan:Starting scan ...
INFO:kladi.motif_scanning.moods_scan:Found 1000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 2000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 3000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 4000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 5000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 6000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 7000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 8000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 9000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 10000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 11000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 12000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 13000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 14000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 15000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 16000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 17000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 18000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 19000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 20000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 21000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 22000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 23000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 24000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 25000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 26000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 27000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 28000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 29000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 30000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 31000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 32000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Found 33000000 motif hits ...
INFO:kladi.motif_scanning.moods_scan:Formatting hits matrix ...
Save hits data
atac_model.save_hits_data('data/mouse_prostate/atac_model_data.pth')
# to load: atac_model.load_hits_data('data/mouse_prostate/atac_model_data.pth')
Filter factors in the hits data for those found in the expression data
atac_model.filter_factors(joint_rna_data.var_names)
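Conceptually, this filtering step keeps only the motifs whose factor is also measured in the expression data. A rough sketch of the idea follows. The helper filter_factor_names is hypothetical, the case-insensitive matching is my assumption rather than documented kladi behavior, and the toy gene list is made up. The motif IDs and factor names are taken from the enrichment output shown later in this tutorial.

```python
def filter_factor_names(motif_factors, expressed_genes):
    """Keep (motif_id, factor) pairs whose factor name appears among the expressed genes."""
    expressed = {gene.upper() for gene in expressed_genes}  # case-insensitive match (assumption)
    return [(motif_id, name) for motif_id, name in motif_factors if name.upper() in expressed]


motif_factors = [("MA1104.2", "GATA6"), ("MA0148.4", "FOXA1"), ("MA1279.1", "COG1")]
toy_var_names = ["Gata6", "Foxa1", "Lef1"]  # made-up stand-in for joint_rna_data.var_names

kept = filter_factor_names(motif_factors, toy_var_names)
```

Factors with no matching gene are dropped, so later enrichment results only mention TFs whose expression can be inspected.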
List TFs by enrichment in topic:
atac_model.enrich_TFs(8)[:20]
Finding enrichments: 100%|██████████| 413/413 [00:01<00:00, 257.00it/s]
[('MA1104.2', 'GATA6', 3.5518861905786013e-144, 1.6467506658877005), ('MA0482.2', 'GATA4', 1.2606122867973607e-143, 1.6522788970078472), ('MA0036.3', 'GATA2', 4.595947933961744e-140, 1.6733525252400563), ('MA0037.3', 'GATA3', 9.894863993686855e-136, 1.6706712214179162), ('MA0851.1', 'FOXJ3', 2.957152669167175e-129, 1.462610474242792), ('MA0032.2', 'FOXC1', 1.5780332790124106e-116, 1.4577195639332938), ('MA0148.4', 'FOXA1', 8.758780748174128e-104, 1.480175093672573), ('MA0481.3', 'FOXP1', 4.0919008371602687e-103, 1.463839172902741), ('MA1606.1', 'FOXF1', 2.221070408134676e-100, 1.4397008130740787), ('MA1103.2', 'FOXK2', 1.9053891898580628e-93, 1.4724966644701105), ('MA0848.1', 'FOXO4', 1.9085179842322036e-91, 1.4661303817303075), ('MA0593.1', 'FOXP2', 1.84289197176993e-87, 1.3934729887289605), ('MA0852.2', 'FOXK1', 8.408089792885883e-83, 1.434627165311404), ('MA0614.1', 'FOXJ2', 2.023650415481049e-82, 1.4621263094375214), ('MA0157.2', 'FOXO3', 5.3473680775209625e-81, 1.489623305329279), ('MA1279.1', 'COG1', 4.317412731989247e-43, 1.2095407146898005), ('MA0040.1', 'FOXQ1', 4.543233628267037e-41, 1.2973901655150888), ('MA0480.1', 'FOXO1', 1.654603318170009e-40, 1.2534923743050739), ('MA1105.2', 'GRHL2', 1.0642802206687087e-35, 1.2787231325757487), ('MA1489.1', 'FOXN3', 6.40647894804195e-34, 1.491252398572033)]
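The raw tuples are hard to scan, so here is a small sketch that turns them into labeled rows with a -log10 p-value. I am reading the first three fields as motif ID, factor name and p-value based on their appearance. The source does not say what the fourth value is, so the sketch carries it along unnamed as fourth_value. The function summarize_enrichments is my own helper, and the input is copied from three rows of the output above.

```python
import math


def summarize_enrichments(hits):
    """Convert (motif_id, factor, p_value, fourth_value) tuples into rows sorted by significance."""
    rows = [
        {
            "motif_id": motif_id,
            "factor": factor,
            "neg_log10_p": -math.log10(p_value),
            "fourth_value": fourth_value,  # meaning not stated in the source
        }
        for motif_id, factor, p_value, fourth_value in hits
    ]
    return sorted(rows, key=lambda row: row["neg_log10_p"], reverse=True)


top_hits = [
    ("MA0851.1", "FOXJ3", 2.957152669167175e-129, 1.462610474242792),
    ("MA1104.2", "GATA6", 3.5518861905786013e-144, 1.6467506658877005),
    ("MA0482.2", "GATA4", 1.2606122867973607e-143, 1.6522788970078472),
]

rows = summarize_enrichments(top_hits)
ranked_factors = [row["factor"] for row in rows]
```

The resulting list of dicts can be passed straight to pd.DataFrame for display or plotting.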
Compare topic TF enrichments:
atac_model.plot_compare_module_enrichments(8,4, label_closeness=1.5, figsize=(10,10), pval_threshold=(1e-25,1e-25))
Finding enrichments: 100%|██████████| 413/413 [00:01<00:00, 259.79it/s]
Finding enrichments: 100%|██████████| 413/413 [00:01<00:00, 263.75it/s]