Batch-wise causarray Analysis — Replogle-E-K562#
This tutorial demonstrates the batch fitting workflow of causarray on a subset of the Replogle et al. 2022 genome-wide CRISPRi screen in K562 cells (Replogle et al. 2022, Cell).
Dataset (tutorial subset)
Full dataset: 309 915 cells × 8 563 genes, 2 021 perturbations
Tutorial subset: top-200 most-abundant perturbations + 2 000 ctrl cells → 79 865 cells, created by
prep_tutorial_data.py
fit_gcate on 79 865 cells with 200 perturbations simultaneously would require fitting a GLM with 201 treatment columns, which is both numerically difficult and memory-intensive. gcate_lfc_batch avoids this by pairing each batch of 15 perturbations with a fixed pool of control cells, bounding peak memory to one batch at a time regardless of total pert count.Pipeline overview
replogle_subset.h5ad
|
v prep_causarray_data
Y, A, X (count matrix, treatment indicators, covariates)
|
v gcate_lfc_batch
df_res (tau, std, stat, pvalue, padj per gene x pert)
|
v visualization
volcano, discovery bar chart
[1]:
import sys
sys.path.insert(0, '../../..')
import time
import numpy as np
import pandas as pd
import scipy.sparse as sp
import matplotlib.pyplot as plt
import scanpy as sc
from causarray import prep_causarray_data, gcate_lfc_batch
from causarray.gcate import plot_r
import causarray
print('causarray version:', causarray.__version__)
causarray version: 0.0.6
1. Load data#
The subset was prepared by prep_tutorial_data.py. It contains the 200 perturbations with the most cells plus a random sample of 2 000 ctrl (non-targeting) cells.
[2]:
adata = sc.read_h5ad('replogle_subset.h5ad')
print(adata)
CTRL_LABEL = adata.uns['ctrl_label'] # 'non-targeting'
PERT_COL = adata.uns['pert_col'] # 'gene'
vc = adata.obs[PERT_COL].value_counts()
print(f'\nCtrl cells : {vc[CTRL_LABEL]}')
print(f'Pert cells : {vc.drop(CTRL_LABEL).sum():,} across {len(vc) - 1} perturbations')
print(f'Cells/pert : median {vc.drop(CTRL_LABEL).median():.0f}, '
f'range {vc.drop(CTRL_LABEL).min()}-{vc.drop(CTRL_LABEL).max()}')
AnnData object with n_obs × n_vars = 79865 × 8563
obs: 'batch', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'guide_id', 'percent_mito', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor', 'core_adjusted_UMI_count', 'disease', 'cancer', 'cell_line', 'sex', 'age', 'perturbation', 'organism', 'perturbation_type', 'tissue_type', 'ncounts', 'ngenes', 'nperts', 'percent_ribo'
var: 'chr', 'start', 'end', 'class', 'strand', 'length', 'in_matrix', 'mean', 'std', 'cv', 'fano', 'ensembl_id', 'ncounts', 'ncells'
uns: 'ctrl_label', 'pert_col'
Ctrl cells : 2000
Pert cells : 77,865 across 200 perturbations
Cells/pert : median 346, range 273-1996
2. Prepare causarray inputs#
[3]:
Y_raw = adata.X.toarray() if sp.issparse(adata.X) else np.array(adata.X)
Y = pd.DataFrame(Y_raw, columns=adata.var_names.tolist())
del Y_raw
A = (pd.get_dummies(adata.obs[PERT_COL].astype(str), drop_first=False)
.drop(columns=[CTRL_LABEL]))
Y, A, X, X_A = prep_causarray_data(Y, A)
print(f'Y : {Y.shape} (cells x genes)')
print(f'A : {A.shape} (cells x perturbations)')
print(f'X : {X.shape} X_A : {X_A.shape}')
Y : (79865, 8563) (cells x genes)
A : (79865, 200) (cells x perturbations)
X : (79865, 1) X_A : (79865, 2)
3. Number of latent factors#
We select r using the JIC criterion computed by estimate_r. The pre-computed results are loaded from replogle-r.csv (generated by estimate_r_replogle.py; re-run that script to reproduce them).
Why estimate ``r`` on control cells?estimate_rfits GCATE internally, which is expensive at full dataset scale (79 865 cells). The latent factors capture confounding variation that is present in the baseline (control) transcriptome, so a ctrl-priority subsample is both much faster and statistically equivalent to using all cells.estimate_r_replogle.pyuses 2 000 ctrl + 20 × ≤200 pert cells for this reason.You can replicate this with themax_cellsargument:df_r = estimate_r(Y, X, A, r_values, family='nb', max_cells=6000)
estimate_rwill automatically prioritise control cells when subsampling.
[4]:
df_r = pd.read_csv('replogle-r.csv')
fig = plot_r(df_r)
plt.tight_layout()
plt.show()
best_r = int(df_r.loc[df_r['JIC'].idxmin(), 'r'])
print(f'\nSelected r = {best_r} (min JIC)')
print(df_r.to_string(index=False))
Selected r = 30 (min JIC)
r deviance nu JIC
0 -13.524045 0.030448 -13.493596
5 -13.525350 0.037698 -13.487652
10 -13.526765 0.044947 -13.481817
15 -13.528364 0.052197 -13.476167
20 -13.530392 0.059447 -13.470945
25 -13.534008 0.066696 -13.467311
30 -13.599773 0.073946 -13.525827
4. Batch fitting with gcate_lfc_batch#
Key parameters:
Parameter |
Value |
Effect |
|---|---|---|
|
15 |
~15 perts processed per GCATE call |
|
2000 |
<= 2000 pert cells per batch (ctrl added on top) |
|
2000 |
Fixed ctrl subsample shared across all batches |
|
|
Resume from disk if interrupted |
With 200 perturbations and batch_size=15, gcate_lfc_batch internally computes n_batches = ceil(200 / 15) = 14 and then calls numpy.array_split to distribute the remainder evenly, giving 4 batches of 15 and 10 batches of 14 — no tiny tail batch. Each batch has at most 2000 ctrl + 2000 pert cells = 4000 cells.
[5]:
R = best_r
t0 = time.perf_counter()
df_res = gcate_lfc_batch(
Y, X, A, R,
W_A=X_A,
batch_size=15,
max_cells=2000,
n_ctrl=2000,
family='nb',
cache_path='replogle_results.h5',
random_state=0,
verbose=True,
gcate_kwargs=dict(
kwargs_es_1=dict(rel_tol=2e-4, max_iters=30),
kwargs_es_2=dict(rel_tol=2e-4, max_iters=30),
),
)
t_total = time.perf_counter() - t0
print(f'\nTotal wall time: {t_total/60:.1f} min')
print(f'Result shape : {df_res.shape}')
df_res.head()
'Pre-estimating dispersion on ctrl cell subsample...'
GCATE batches: 0%| | 0/14 [00:00<?, ?batch/s]
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 7%|▋ | 1/14 [02:36<33:50, 156.22s/batch]
'Batch 1/14: 15 perts, 4000 cells, 156.2s (avg 156.2s/batch, ETA 0.6h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 14%|█▍ | 2/14 [05:01<29:55, 149.65s/batch]
'Batch 2/14: 15 perts, 4000 cells, 145.0s (avg 150.6s/batch, ETA 0.5h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 21%|██▏ | 3/14 [07:25<27:00, 147.30s/batch]
'Batch 3/14: 15 perts, 4000 cells, 144.5s (avg 148.6s/batch, ETA 0.5h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 29%|██▊ | 4/14 [09:51<24:28, 146.80s/batch]
'Batch 4/14: 15 perts, 4000 cells, 146.0s (avg 148.0s/batch, ETA 0.4h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 36%|███▌ | 5/14 [12:15<21:50, 145.65s/batch]
'Batch 5/14: 14 perts, 4000 cells, 143.6s (avg 147.1s/batch, ETA 0.4h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 43%|████▎ | 6/14 [14:39<19:20, 145.02s/batch]
'Batch 6/14: 14 perts, 4000 cells, 143.8s (avg 146.5s/batch, ETA 0.3h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 50%|█████ | 7/14 [17:03<16:52, 144.69s/batch]
'Batch 7/14: 14 perts, 4000 cells, 144.0s (avg 146.2s/batch, ETA 0.3h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 57%|█████▋ | 8/14 [19:27<14:27, 144.56s/batch]
'Batch 8/14: 14 perts, 4000 cells, 144.3s (avg 145.9s/batch, ETA 0.2h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 64%|██████▍ | 9/14 [21:52<12:03, 144.68s/batch]
'Batch 9/14: 14 perts, 4000 cells, 144.9s (avg 145.8s/batch, ETA 0.2h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 71%|███████▏ | 10/14 [24:19<09:41, 145.31s/batch]
'Batch 10/14: 14 perts, 4000 cells, 146.7s (avg 145.9s/batch, ETA 0.2h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 79%|███████▊ | 11/14 [26:43<07:15, 145.05s/batch]
'Batch 11/14: 14 perts, 4000 cells, 144.5s (avg 145.8s/batch, ETA 0.1h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 86%|████████▌ | 12/14 [29:06<04:48, 144.41s/batch]
'Batch 12/14: 14 perts, 4000 cells, 142.9s (avg 145.5s/batch, ETA 0.1h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 93%|█████████▎| 13/14 [31:28<02:23, 143.74s/batch]
'Batch 13/14: 14 perts, 4000 cells, 142.2s (avg 145.3s/batch, ETA 0.0h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 100%|██████████| 14/14 [33:55<00:00, 145.40s/batch]
'Batch 14/14: 14 perts, 4000 cells, 146.8s (avg 145.4s/batch, ETA 0.0h)'
'Estimating LFC...'
{'a': 15, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([-0.00384929, 0.00172282, 0.00193157, ..., 0.00084838,
-0.0056416 , 0.00156955], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 15/15 [00:14<00:00, 1.05it/s]
'Estimating LFC...'
{'a': 15, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([ 0.00171867, -0.00482039, -0.00717121, ..., 0.00177398,
0.00055709, 0.00202624], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 15/15 [00:13<00:00, 1.08it/s]
'Estimating LFC...'
{'a': 15, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([ 0.00137385, -0.00517628, -0.00775641, ..., 0.00135865,
0.00016659, 0.00190154], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 15/15 [00:13<00:00, 1.09it/s]
'Estimating LFC...'
{'a': 15, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([ 0.00152472, -0.00516711, -0.00758565, ..., 0.00153186,
0.00045201, 0.00149752], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 15/15 [00:13<00:00, 1.09it/s]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([ 0.00144432, -0.00588193, -0.00854897, ..., 0.00274775,
0.00117871, 0.00019382], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00, 1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([-0.00564738, 0.00106602, -0.00698854, ..., 0.00252708,
0.00109336, -0.0002114 ], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00, 1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([ 0.00182445, -0.00505226, -0.00739483, ..., 0.00185949,
0.0007161 , 0.00313411], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00, 1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([ 0.00147441, -0.0058698 , -0.00777271, ..., 0.00278432,
0.00144021, 0.00027878], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00, 1.11s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([ 0.00187935, -0.00478473, -0.0050402 , ..., 0.00173506,
0.00046157, 0.00227872], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00, 1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([ 0.00098952, -0.00562532, -0.00672674, ..., 0.00201541,
0.00113671, -0.00010183], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00, 1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([ 0.00192672, 0.00165499, -0.00620323, ..., 0.00294756,
0.00182595, 0.00075213], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00, 1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([ 0.00206825, 0.00082535, -0.00507758, ..., -0.00557712,
-0.00421316, 0.00059294], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00, 1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([-0.00065913, -0.00758703, 0.00024738, ..., -0.00181731,
-0.00093586, 0.00043236], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00, 1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
'class_weight': 'balanced',
'fit_intercept': False,
'random_state': 0,
'verbose': False}
{'offset': array([ 0.00170334, -0.004485 , 0.00183328, ..., 0.00178951,
0.00073638, -0.00187221], shape=(4000,)),
'random_state': 0,
'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00, 1.12s/it]
Total wall time: 149.2 min
Result shape : (1712600, 11)
[5]:
| gene_names | tau | std | stat | rej | pvalue | padj | pvalue_emp_null_adj | padj_emp_null_adj | trt | batch | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | LINC01409 | -0.265192 | 0.586343 | -0.452281 | 0.0 | 0.652068 | 0.902596 | 0.701893 | 0.978885 | AATF | 0 |
| 1 | LINC01128 | -0.419937 | 2.613165 | -0.160700 | 0.0 | 0.872674 | 0.943767 | 0.451633 | 0.978885 | AATF | 0 |
| 2 | NOC2L | -0.041192 | 0.077027 | -0.534772 | 0.0 | 0.593921 | 0.884325 | 0.781044 | 0.984638 | AATF | 0 |
| 3 | KLHL17 | -0.254840 | 0.581965 | -0.437896 | 0.0 | 0.662441 | 0.906762 | 0.688398 | 0.978885 | AATF | 0 |
| 4 | HES4 | -0.308193 | 0.271868 | -1.133613 | 0.0 | 0.259441 | 0.750317 | 0.628480 | 0.978885 | AATF | 0 |
5. Results#
5.1 Discovery summary#
[6]:
FDR = 0.05
sig = df_res[df_res['padj'] < FDR]
print(f'Significant (padj < {FDR}): {len(sig):,} gene x pert pairs')
print(f' Perturbations with >= 1 hit: {sig["trt"].nunique()}')
print(f' Unique genes affected : {sig["gene_names"].nunique():,}')
disc_per_pert = sig.groupby('trt').size().sort_values(ascending=False)
print(f'\nTop-10 perts by discovery count:')
print(disc_per_pert.head(10).to_string())
Significant (padj < 0.05): 221,657 gene x pert pairs
Perturbations with >= 1 hit: 172
Unique genes affected : 8,133
Top-10 perts by discovery count:
trt
KIF11 4904
RPL3 4858
DONSON 4567
QARS 4096
SLC39A9 4015
EIF3H 3925
PMPCB 3878
NUP205 3614
HSPA5 3540
SSRP1 3432
5.2 Volcano plot (all perturbations combined)#
[7]:
fig, ax = plt.subplots(figsize=(7, 5))
colors = np.where(df_res['padj'] < FDR, '#e74c3c', '#aaaaaa')
ax.scatter(
df_res['tau'], -np.log10(df_res['pvalue'].clip(1e-300)),
c=colors, s=1, alpha=0.4, rasterized=True,
)
ax.axhline(-np.log10(0.05 / len(df_res)), color='navy', lw=0.8, ls='--',
label='Bonferroni')
ax.set_ylim(0, 10)
ax.set_xlabel('Estimated LFC (tau)')
ax.set_ylabel('-log10(p-value)')
ax.set_title('Volcano plot — Replogle-E-K562 (top-200 perts)')
ax.legend(fontsize=8)
plt.tight_layout()
plt.show()
5.3 Discovery count per perturbation#
Each bar shows the number of significant genes (BH-adjusted p-value < 0.05) for one perturbation. Wide variation in bar heights reflects genuine differences in perturbation strength rather than technical artefacts, since causarray deconfounds shared latent variation before testing.
[8]:
top_n = 30
top_disc = disc_per_pert.head(top_n)
fig, ax = plt.subplots(figsize=(10, 4))
top_disc.plot(kind='bar', ax=ax, color='steelblue', edgecolor='none')
ax.set_ylabel(f'Significant genes (padj < {FDR})')
ax.set_title(f'Top-{top_n} perturbations by discovery count')
ax.tick_params(axis='x', rotation=45, labelsize=8)
plt.tight_layout()
plt.show()