This page was generated from docs/tutorial/replogle/replogle-py.ipynb. Interactive online version: Binder badge.

Batch-wise causarray Analysis — Replogle-E-K562#

This tutorial demonstrates the batch fitting workflow of causarray on a subset of the Replogle et al. 2022 genome-wide CRISPRi screen in K562 cells (Replogle et al. 2022, Cell).

Dataset (tutorial subset)

  • Full dataset: 309 915 cells × 8 563 genes, 2 021 perturbations

  • Tutorial subset: top-200 most-abundant perturbations + 2 000 ctrl cells → 79 865 cells, created by prep_tutorial_data.py

Why batch fitting?
Running fit_gcate on 79 865 cells with 200 perturbations simultaneously would require fitting a GLM with 201 treatment columns, which is both numerically difficult and memory-intensive. gcate_lfc_batch avoids this by pairing each batch of 15 perturbations with a fixed pool of control cells, bounding peak memory to one batch at a time regardless of total pert count.

Pipeline overview

replogle_subset.h5ad
      |
      v  prep_causarray_data
   Y, A, X                  (count matrix, treatment indicators, covariates)
      |
      v  gcate_lfc_batch
   df_res                   (tau, std, stat, pvalue, padj per gene x pert)
      |
      v  visualization
   volcano, discovery bar chart
[1]:
import sys
sys.path.insert(0, '../../..')

import time
import numpy as np
import pandas as pd
import scipy.sparse as sp
import matplotlib.pyplot as plt
import scanpy as sc

from causarray import prep_causarray_data, gcate_lfc_batch
from causarray.gcate import plot_r

import causarray
print('causarray version:', causarray.__version__)

causarray version: 0.0.6

1. Load data#

The subset was prepared by prep_tutorial_data.py. It contains the 200 perturbations with the most cells plus a random sample of 2 000 ctrl (non-targeting) cells.

[2]:
adata = sc.read_h5ad('replogle_subset.h5ad')
print(adata)

CTRL_LABEL = adata.uns['ctrl_label']   # 'non-targeting'
PERT_COL   = adata.uns['pert_col']     # 'gene'

vc = adata.obs[PERT_COL].value_counts()
print(f'\nCtrl cells  : {vc[CTRL_LABEL]}')
print(f'Pert cells  : {vc.drop(CTRL_LABEL).sum():,}  across {len(vc) - 1} perturbations')
print(f'Cells/pert  : median {vc.drop(CTRL_LABEL).median():.0f}, '
      f'range {vc.drop(CTRL_LABEL).min()}-{vc.drop(CTRL_LABEL).max()}')

AnnData object with n_obs × n_vars = 79865 × 8563
    obs: 'batch', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'guide_id', 'percent_mito', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor', 'core_adjusted_UMI_count', 'disease', 'cancer', 'cell_line', 'sex', 'age', 'perturbation', 'organism', 'perturbation_type', 'tissue_type', 'ncounts', 'ngenes', 'nperts', 'percent_ribo'
    var: 'chr', 'start', 'end', 'class', 'strand', 'length', 'in_matrix', 'mean', 'std', 'cv', 'fano', 'ensembl_id', 'ncounts', 'ncells'
    uns: 'ctrl_label', 'pert_col'

Ctrl cells  : 2000
Pert cells  : 77,865  across 200 perturbations
Cells/pert  : median 346, range 273-1996

2. Prepare causarray inputs#

[3]:
Y_raw = adata.X.toarray() if sp.issparse(adata.X) else np.array(adata.X)
Y = pd.DataFrame(Y_raw, columns=adata.var_names.tolist())
del Y_raw

A = (pd.get_dummies(adata.obs[PERT_COL].astype(str), drop_first=False)
       .drop(columns=[CTRL_LABEL]))

Y, A, X, X_A = prep_causarray_data(Y, A)

print(f'Y : {Y.shape}  (cells x genes)')
print(f'A : {A.shape}  (cells x perturbations)')
print(f'X : {X.shape}  X_A : {X_A.shape}')

Y : (79865, 8563)  (cells x genes)
A : (79865, 200)  (cells x perturbations)
X : (79865, 1)  X_A : (79865, 2)

3. Number of latent factors#

We select r using the JIC criterion computed by estimate_r. The pre-computed results are loaded from replogle-r.csv (generated by estimate_r_replogle.py; re-run that script to reproduce them).

Why estimate ``r`` on control cells?
estimate_r fits GCATE internally, which is expensive at full dataset scale (79 865 cells). The latent factors capture confounding variation that is present in the baseline (control) transcriptome, so a ctrl-priority subsample is both much faster and statistically equivalent to using all cells.
estimate_r_replogle.py uses 2 000 ctrl + 20 × ≤200 pert cells for this reason.
You can replicate this with the max_cells argument:
df_r = estimate_r(Y, X, A, r_values, family='nb', max_cells=6000)

estimate_r will automatically prioritise control cells when subsampling.

[4]:
df_r = pd.read_csv('replogle-r.csv')
fig = plot_r(df_r)
plt.tight_layout()
plt.show()

best_r = int(df_r.loc[df_r['JIC'].idxmin(), 'r'])
print(f'\nSelected r = {best_r}  (min JIC)')
print(df_r.to_string(index=False))

../../_images/tutorial_replogle_replogle-py_7_0.png

Selected r = 30  (min JIC)
 r   deviance       nu        JIC
 0 -13.524045 0.030448 -13.493596
 5 -13.525350 0.037698 -13.487652
10 -13.526765 0.044947 -13.481817
15 -13.528364 0.052197 -13.476167
20 -13.530392 0.059447 -13.470945
25 -13.534008 0.066696 -13.467311
30 -13.599773 0.073946 -13.525827

4. Batch fitting with gcate_lfc_batch#

Key parameters:

Parameter

Value

Effect

batch_size

15

~15 perts processed per GCATE call

max_cells

2000

<= 2000 pert cells per batch (ctrl added on top)

n_ctrl

2000

Fixed ctrl subsample shared across all batches

cache_path

'replogle_results.h5'

Resume from disk if interrupted

With 200 perturbations and batch_size=15, gcate_lfc_batch internally computes n_batches = ceil(200 / 15) = 14 and then calls numpy.array_split to distribute the remainder evenly, giving 4 batches of 15 and 10 batches of 14 — no tiny tail batch. Each batch has at most 2000 ctrl + 2000 pert cells = 4000 cells.

[5]:
R = best_r

t0 = time.perf_counter()
df_res = gcate_lfc_batch(
    Y, X, A, R,
    W_A=X_A,
    batch_size=15,
    max_cells=2000,
    n_ctrl=2000,
    family='nb',
    cache_path='replogle_results.h5',
    random_state=0,
    verbose=True,
    gcate_kwargs=dict(
        kwargs_es_1=dict(rel_tol=2e-4, max_iters=30),
        kwargs_es_2=dict(rel_tol=2e-4, max_iters=30),
    ),
)
t_total = time.perf_counter() - t0

print(f'\nTotal wall time: {t_total/60:.1f} min')
print(f'Result shape   : {df_res.shape}')
df_res.head()

'Pre-estimating dispersion on ctrl cell subsample...'
GCATE batches:   0%|          | 0/14 [00:00<?, ?batch/s]
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:   7%|▋         | 1/14 [02:36<33:50, 156.22s/batch]
'Batch 1/14: 15 perts, 4000 cells, 156.2s  (avg 156.2s/batch, ETA 0.6h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  14%|█▍        | 2/14 [05:01<29:55, 149.65s/batch]
'Batch 2/14: 15 perts, 4000 cells, 145.0s  (avg 150.6s/batch, ETA 0.5h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  21%|██▏       | 3/14 [07:25<27:00, 147.30s/batch]
'Batch 3/14: 15 perts, 4000 cells, 144.5s  (avg 148.6s/batch, ETA 0.5h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  29%|██▊       | 4/14 [09:51<24:28, 146.80s/batch]
'Batch 4/14: 15 perts, 4000 cells, 146.0s  (avg 148.0s/batch, ETA 0.4h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  36%|███▌      | 5/14 [12:15<21:50, 145.65s/batch]
'Batch 5/14: 14 perts, 4000 cells, 143.6s  (avg 147.1s/batch, ETA 0.4h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  43%|████▎     | 6/14 [14:39<19:20, 145.02s/batch]
'Batch 6/14: 14 perts, 4000 cells, 143.8s  (avg 146.5s/batch, ETA 0.3h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  50%|█████     | 7/14 [17:03<16:52, 144.69s/batch]
'Batch 7/14: 14 perts, 4000 cells, 144.0s  (avg 146.2s/batch, ETA 0.3h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  57%|█████▋    | 8/14 [19:27<14:27, 144.56s/batch]
'Batch 8/14: 14 perts, 4000 cells, 144.3s  (avg 145.9s/batch, ETA 0.2h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  64%|██████▍   | 9/14 [21:52<12:03, 144.68s/batch]
'Batch 9/14: 14 perts, 4000 cells, 144.9s  (avg 145.8s/batch, ETA 0.2h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  71%|███████▏  | 10/14 [24:19<09:41, 145.31s/batch]
'Batch 10/14: 14 perts, 4000 cells, 146.7s  (avg 145.9s/batch, ETA 0.2h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  79%|███████▊  | 11/14 [26:43<07:15, 145.05s/batch]
'Batch 11/14: 14 perts, 4000 cells, 144.5s  (avg 145.8s/batch, ETA 0.1h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  86%|████████▌ | 12/14 [29:06<04:48, 144.41s/batch]
'Batch 12/14: 14 perts, 4000 cells, 142.9s  (avg 145.5s/batch, ETA 0.1h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches:  93%|█████████▎| 13/14 [31:28<02:23, 143.74s/batch]
'Batch 13/14: 14 perts, 4000 cells, 142.2s  (avg 145.3s/batch, ETA 0.0h)'
'Fitting GCATE (step 1)...'
'Fitting GCATE (step 2)...'
GCATE batches: 100%|██████████| 14/14 [33:55<00:00, 145.40s/batch]
'Batch 14/14: 14 perts, 4000 cells, 146.8s  (avg 145.4s/batch, ETA 0.0h)'
'Estimating LFC...'
{'a': 15, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([-0.00384929,  0.00172282,  0.00193157, ...,  0.00084838,
       -0.0056416 ,  0.00156955], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'

'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 15/15 [00:14<00:00,  1.05it/s]
'Estimating LFC...'
{'a': 15, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([ 0.00171867, -0.00482039, -0.00717121, ...,  0.00177398,
        0.00055709,  0.00202624], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 15/15 [00:13<00:00,  1.08it/s]
'Estimating LFC...'
{'a': 15, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([ 0.00137385, -0.00517628, -0.00775641, ...,  0.00135865,
        0.00016659,  0.00190154], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 15/15 [00:13<00:00,  1.09it/s]
'Estimating LFC...'
{'a': 15, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([ 0.00152472, -0.00516711, -0.00758565, ...,  0.00153186,
        0.00045201,  0.00149752], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 15/15 [00:13<00:00,  1.09it/s]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([ 0.00144432, -0.00588193, -0.00854897, ...,  0.00274775,
        0.00117871,  0.00019382], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00,  1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([-0.00564738,  0.00106602, -0.00698854, ...,  0.00252708,
        0.00109336, -0.0002114 ], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00,  1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([ 0.00182445, -0.00505226, -0.00739483, ...,  0.00185949,
        0.0007161 ,  0.00313411], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00,  1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([ 0.00147441, -0.0058698 , -0.00777271, ...,  0.00278432,
        0.00144021,  0.00027878], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00,  1.11s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([ 0.00187935, -0.00478473, -0.0050402 , ...,  0.00173506,
        0.00046157,  0.00227872], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00,  1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([ 0.00098952, -0.00562532, -0.00672674, ...,  0.00201541,
        0.00113671, -0.00010183], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00,  1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([ 0.00192672,  0.00165499, -0.00620323, ...,  0.00294756,
        0.00182595,  0.00075213], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00,  1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([ 0.00206825,  0.00082535, -0.00507758, ..., -0.00557712,
       -0.00421316,  0.00059294], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00,  1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([-0.00065913, -0.00758703,  0.00024738, ..., -0.00181731,
       -0.00093586,  0.00043236], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00,  1.10s/it]
'Estimating LFC...'
{'a': 14, 'd': 31, 'd_A': 32, 'estimands': 'LFC', 'n': 4000, 'p': 8563}
{'C': 1.0,
 'class_weight': 'balanced',
 'fit_intercept': False,
 'random_state': 0,
 'verbose': False}
{'offset': array([ 0.00170334, -0.004485  ,  0.00183328, ...,  0.00178951,
        0.00073638, -0.00187221], shape=(4000,)),
 'random_state': 0,
 'verbose': True}
'Fit propensity score models...'
'Fit outcome models...'
'Fitting nb GLM (fast)...'
'Estimating AIPW mean...'
100%|██████████| 14/14 [00:15<00:00,  1.12s/it]

Total wall time: 149.2 min
Result shape   : (1712600, 11)
[5]:
gene_names tau std stat rej pvalue padj pvalue_emp_null_adj padj_emp_null_adj trt batch
0 LINC01409 -0.265192 0.586343 -0.452281 0.0 0.652068 0.902596 0.701893 0.978885 AATF 0
1 LINC01128 -0.419937 2.613165 -0.160700 0.0 0.872674 0.943767 0.451633 0.978885 AATF 0
2 NOC2L -0.041192 0.077027 -0.534772 0.0 0.593921 0.884325 0.781044 0.984638 AATF 0
3 KLHL17 -0.254840 0.581965 -0.437896 0.0 0.662441 0.906762 0.688398 0.978885 AATF 0
4 HES4 -0.308193 0.271868 -1.133613 0.0 0.259441 0.750317 0.628480 0.978885 AATF 0

5. Results#

5.1 Discovery summary#

[6]:
FDR = 0.05
sig = df_res[df_res['padj'] < FDR]

print(f'Significant (padj < {FDR}): {len(sig):,} gene x pert pairs')
print(f'  Perturbations with >= 1 hit: {sig["trt"].nunique()}')
print(f'  Unique genes affected      : {sig["gene_names"].nunique():,}')

disc_per_pert = sig.groupby('trt').size().sort_values(ascending=False)
print(f'\nTop-10 perts by discovery count:')
print(disc_per_pert.head(10).to_string())

Significant (padj < 0.05): 221,657 gene x pert pairs
  Perturbations with >= 1 hit: 172
  Unique genes affected      : 8,133

Top-10 perts by discovery count:
trt
KIF11      4904
RPL3       4858
DONSON     4567
QARS       4096
SLC39A9    4015
EIF3H      3925
PMPCB      3878
NUP205     3614
HSPA5      3540
SSRP1      3432

5.2 Volcano plot (all perturbations combined)#

[7]:
fig, ax = plt.subplots(figsize=(7, 5))

colors = np.where(df_res['padj'] < FDR, '#e74c3c', '#aaaaaa')
ax.scatter(
    df_res['tau'], -np.log10(df_res['pvalue'].clip(1e-300)),
    c=colors, s=1, alpha=0.4, rasterized=True,
)
ax.axhline(-np.log10(0.05 / len(df_res)), color='navy', lw=0.8, ls='--',
           label='Bonferroni')
ax.set_ylim(0, 10)
ax.set_xlabel('Estimated LFC (tau)')
ax.set_ylabel('-log10(p-value)')
ax.set_title('Volcano plot — Replogle-E-K562 (top-200 perts)')
ax.legend(fontsize=8)
plt.tight_layout()
plt.show()

../../_images/tutorial_replogle_replogle-py_13_0.png

5.3 Discovery count per perturbation#

Each bar shows the number of significant genes (BH-adjusted p-value < 0.05) for one perturbation. Wide variation in bar heights reflects genuine differences in perturbation strength rather than technical artefacts, since causarray deconfounds shared latent variation before testing.

[8]:
top_n = 30
top_disc = disc_per_pert.head(top_n)

fig, ax = plt.subplots(figsize=(10, 4))
top_disc.plot(kind='bar', ax=ax, color='steelblue', edgecolor='none')
ax.set_ylabel(f'Significant genes (padj < {FDR})')
ax.set_title(f'Top-{top_n} perturbations by discovery count')
ax.tick_params(axis='x', rotation=45, labelsize=8)
plt.tight_layout()
plt.show()

../../_images/tutorial_replogle_replogle-py_15_0.png