Doubly-robust semiparametric inference
LFC() estimates per-gene log-fold changes using a doubly-robust AIPW
estimator. It requires the augmented covariate matrix W = [X | U] where
U are the latent factors from fit_gcate(), and produces a DataFrame
with effect estimates, standard errors, and BH-adjusted p-values.
For screens with hundreds of perturbations use gcate_lfc_batch(), which
runs GCATE and LFC in batches to keep peak memory bounded.
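As a minimal sketch of how the augmented covariate matrix W = [X | U] can be assembled, the snippet below stacks an observed covariate matrix and a latent-factor matrix column-wise with NumPy. The random `U` is only a stand-in for the factors that fit_gcate() would estimate, and the dimensions are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, r = 100, 3, 2  # cells, observed covariates, latent factors (illustrative)

# Observed covariates X, here with an intercept in the first column.
X = np.column_stack([np.ones(n), rng.normal(size=(n, d - 1))])

# Placeholder for the latent factors; in practice these come from fit_gcate().
U = rng.normal(size=(n, r))

# Augmented covariate matrix W = [X | U], shape (n, d + r).
W = np.hstack([X, U])
print(W.shape)
```

The resulting `W` has one row per cell, which is the layout the `W` argument of LFC() is documented to expect.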
- causarray.DR_learner.LFC(Y, W, A, W_A=None, family='nb', offset=False, Y_hat=None, pi_hat=None, cross_est=False, mask=None, usevar='unequal', thres_min=0.01, thres_diff=0.01, eps_var=0.0001, fdx=False, fdx_alpha=0.05, fdx_c=0.1, verbose=False, backend: str = 'auto', **kwargs)
Estimate log-fold changes of treatment effects (LFCs) using AIPW.
Fits a doubly-robust AIPW estimator for the log-ratio of counterfactual means E[Y(1)] / E[Y(0)]. Call this after fit_gcate() to incorporate estimated latent factors into the covariate matrix W.
- Parameters:
- Y : array, shape (n, p)
Count matrix of outcomes.
- W : array, shape (n, d)
Covariate matrix, typically [X | U] where U are the latent factors from GCATE.
- A : array, shape (n, a)
Binary treatment indicator matrix.
- W_A : array or None, shape (n, d_A)
Covariate matrix for the propensity model. If None, W is used.
- family : str
GLM family for the outcome model: 'nb' (default) or 'poisson'.
- offset : bool or array-like
Log-scale offset for the outcome model. True computes size factors automatically; False or None disables the offset.
- Y_hat : array or None, shape (n, p, a, 2)
Pre-computed counterfactual predictions. When provided, cross-fitting is skipped.
- pi_hat : array or None, shape (n, a)
Pre-computed propensity scores. When provided, propensity fitting is skipped.
- cross_est : bool
Whether to use cross-estimation for nuisance parameters.
- mask : array or None, shape (n, a)
Boolean mask indicating which cells are used for the final estimand computation. Does not affect cross-fitting or propensity estimation.
- usevar : str
Variance estimator for the AIPW pseudo-outcomes:
'unequal' (default, v0.0.6+): Welch variance s₀²/n₀ + s₁²/n₁ with Welch-Satterthwaite degrees of freedom; p-values use the t-distribution.
'pooled': pooled-variance estimator (s² + eps_var) / n.
Changed in version 0.0.6: Default changed from 'pooled' to 'unequal'. The 'unequal' formula was also corrected from (s₀²/n₀ + s₁²/n₁)/2 to the standard Welch form, which shrinks t-statistics by a factor of ≈ √2 relative to v0.0.5. Pass usevar='pooled' to recover pre-v0.0.6 behaviour.
- thres_min : float
Genes whose maximum counterfactual mean is below this threshold are excluded (reported as tau=0, padj=NaN).
- thres_diff : float
Genes whose counterfactual means differ by less than this value are excluded.
- eps_var : float
Small constant added to per-arm variances to prevent division by zero.
- fdx : bool
Whether to apply FDX control (P(FDP > fdx_c) < fdx_alpha).
- fdx_alpha : float
Significance level for FDX control.
- fdx_c : float
FDP threshold for FDX control.
- verbose : bool
Print progress information.
- backend : str
GLM backend: "auto" (default), "fast" (force crispyx), or "original" (force statsmodels).
- **kwargs
Additional arguments forwarded to the GLM fitting functions.
- Returns:
- df_res : DataFrame
Test results with columns gene_names, tau, std, stat, rej, pvalue, padj, pvalue_emp_null_adj, padj_emp_null_adj (and trt when A has multiple columns).
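The Welch variance and Welch-Satterthwaite degrees of freedom described for usevar='unequal' can be written out as below. This is a generic NumPy sketch of the textbook formulas, not the library's code; `x0` and `x1` are hypothetical per-arm pseudo-outcome vectors for a single gene, and the comparison with the halved pre-v0.0.6 formula only illustrates where the √2 factor comes from.

```python
import numpy as np

def welch_variance(x0, x1):
    """Welch variance s0^2/n0 + s1^2/n1 and Welch-Satterthwaite df."""
    n0, n1 = len(x0), len(x1)
    v0 = np.var(x0, ddof=1) / n0  # s0^2 / n0
    v1 = np.var(x1, ddof=1) / n1  # s1^2 / n1
    var = v0 + v1
    df = var**2 / (v0**2 / (n0 - 1) + v1**2 / (n1 - 1))
    return var, df

rng = np.random.default_rng(0)
x0 = rng.normal(0.0, 1.0, size=40)  # hypothetical control-arm values
x1 = rng.normal(0.5, 2.0, size=60)  # hypothetical treated-arm values

var_welch, df = welch_variance(x0, x1)
var_halved = var_welch / 2          # the pre-v0.0.6 'unequal' formula

diff = x1.mean() - x0.mean()
t_welch = diff / np.sqrt(var_welch)
t_halved = diff / np.sqrt(var_halved)
print(t_halved / t_welch)           # ratio of the two t-statistics
```

Because the halved formula understates the variance by a factor of 2, its t-statistic is √2 times larger than the Welch one, which is the shrinkage noted in the version 0.0.6 change.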
- causarray.DR_learner.LFC_batch(*args, **kwargs)
Deprecated alias for gcate_lfc_batch().
Deprecated: Use gcate_lfc_batch instead. LFC_batch will be removed in a future release.
- causarray.DR_learner.VIM(eta_est, X, id_covs, **kwargs)
Estimate variable importance measures (VIM) for heterogeneous treatment effects.
Decomposes treatment effect variance into components explained by each covariate using conditional average treatment effect (CATE) regression.
- Parameters:
- eta_est : array, shape (n, p)
Influence function values from LFC() or compute_causal_estimand().
- X : array, shape (n, d)
Covariate matrix.
- id_covs : int or array-like of int
Column indices of X to compute VIM for. An integer k is treated as range(k).
- Returns:
- estimation : dict
Dictionary with keys:
'CATE', 'CATE_lower', 'CATE_upper' : array, shape (n_covs, n, p)
Conditional average treatment effect and pointwise confidence band.
'VTE' : array, shape (p,)
Total variance of the treatment effect (marginal).
'CVTE' : array, shape (n_covs, p)
Conditional variance of the treatment effect given each covariate.
'VIM_mean' : array, shape (n_covs, p)
VIM point estimate CVTE / VTE - 1 for each covariate and gene.
'VIM_sd' : array, shape (n_covs, p)
Standard deviation of the VIM estimate.
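To make the shapes in the returned dictionary concrete, the following sketch applies the documented point-estimate formula CVTE / VTE - 1 to toy arrays. The numbers are invented purely for illustration, and the snippet shows only the arithmetic and broadcasting, not how VIM() estimates CVTE or VTE.

```python
import numpy as np

# Toy values for p = 3 genes and n_covs = 2 covariates (illustrative only).
VTE = np.array([2.0, 4.0, 1.0])            # shape (p,)
CVTE = np.array([[1.0, 4.0, 0.5],
                 [2.0, 3.0, 1.0]])         # shape (n_covs, p)

# VTE broadcasts across the covariate axis, giving shape (n_covs, p).
VIM_mean = CVTE / VTE - 1
print(VIM_mean)
```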
- causarray.DR_learner.compute_causal_estimand(estimand, Y, W, A, W_A=None, family='nb', offset=False, Y_hat=None, pi_hat=None, mask=None, fdx=False, fdx_B=1000, fdx_alpha=0.05, fdx_c=0.1, verbose=False, random_state=0, backend: str = 'auto', **kwargs)
Estimate causal treatment effects using AIPW with a user-supplied estimand.
- Parameters:
- estimand : callable
Function that maps influence function values (etas, A) to (eta_est, tau_est, var_est[, df_eff]). See LFC() for an example implementation.
- Y : array, shape (n, p)
Count matrix of outcomes.
- W : array, shape (n, d)
Covariate matrix (including latent factors from GCATE).
- A : array, shape (n, a)
Binary treatment indicator matrix.
- W_A : array or None, shape (n, d_A)
Covariate matrix for the propensity model. If None, W is used.
- family : str
GLM family for the outcome model: 'nb' (default) or 'poisson'.
- offset : bool or array-like
Log-scale offset for the outcome model. True computes size factors automatically; False or None disables the offset.
- Y_hat : array or None, shape (n, p, a, 2)
Pre-computed counterfactual predictions. When provided, cross-fitting is skipped.
- pi_hat : array or None, shape (n, a)
Pre-computed propensity scores. When provided, propensity fitting is skipped.
- mask : array or None, shape (n, a)
Boolean mask indicating which cells are used for the final estimand computation. Does not affect cross-fitting or propensity estimation.
- fdx : bool
Whether to apply FDX control (P(FDP > fdx_c) < fdx_alpha).
- fdx_B : int
Number of bootstrap samples for FDX control.
- fdx_alpha : float
Significance level for FDX control.
- fdx_c : float
FDP threshold for FDX control.
- backend : str
GLM backend: "auto" (default), "fast" (force crispyx), or "original" (force statsmodels).
- verbose : bool
Print progress information.
- **kwargs
Additional arguments forwarded to the GLM fitting functions.
- Returns:
- df_res : DataFrame
Test results with columns gene_names, tau, std, stat, rej, pvalue, padj, pvalue_emp_null_adj, padj_emp_null_adj (and trt when A has multiple columns).
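The estimand argument is described only by its contract, so the following is a hedged sketch of what such a callable might look like for a log-ratio of counterfactual means, using a first-order delta-method expansion. The function name `log_ratio_estimand` is hypothetical, and the layout of `etas` (shape (n, p, 2), with the last axis indexing control and treated arms, single treatment column) is an assumption made for illustration; the layout causarray actually passes may differ, so consult LFC() before adapting it.

```python
import numpy as np

def log_ratio_estimand(etas, A):
    """Hypothetical estimand: log E[Y(1)] - log E[Y(0)] via the delta method.

    ASSUMPTION: etas has shape (n, p, 2); etas[..., 0] and etas[..., 1] hold
    uncentered influence-function values for the control and treated
    counterfactual means. A is accepted to match the contract but unused.
    """
    n = etas.shape[0]
    mu0 = etas[..., 0].mean(axis=0)  # estimated E[Y(0)] per gene
    mu1 = etas[..., 1].mean(axis=0)  # estimated E[Y(1)] per gene
    tau_est = np.log(mu1) - np.log(mu0)
    # First-order influence values of the log-ratio (mean zero by construction).
    eta_est = etas[..., 1] / mu1 - etas[..., 0] / mu0
    var_est = eta_est.var(axis=0, ddof=1) / n
    return eta_est, tau_est, var_est

# Synthetic positive influence values, used only to exercise the function.
rng = np.random.default_rng(0)
etas = rng.gamma(shape=2.0, scale=1.0, size=(500, 4, 2))
A = rng.integers(0, 2, size=(500, 1))
eta_est, tau_est, var_est = log_ratio_estimand(etas, A)
print(tau_est.shape, var_est.shape)
```

The optional fourth return value df_eff is omitted here; the contract marks it as optional.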
- causarray.DR_learner.gcate_lfc_batch(Y, X, A, r, W_A=None, batch_size=10, n_batches=None, max_cells=2000, n_ctrl=2000, family='nb', offset=True, warm_start_U=False, cache_path=None, random_state=0, verbose=False, gcate_kwargs=None, lfc_kwargs=None, **kwargs)
Batch-wise GCATE + doubly-robust LFC estimation.
Partitions perturbations into chunks of batch_size, runs fit_gcate_batch() to estimate per-batch latent confounders, then calls LFC() on each batch independently. All large intermediate arrays (res_1, res_2, Y_hat, pi_hat) are freed immediately after each batch so that peak memory is bounded by one batch's worth of data regardless of the total number of perturbations.
Results can optionally be cached to an HDF5 file (cache_path) so that interrupted runs can be resumed without re-processing completed batches.
- Parameters:
- Y : array-like or DataFrame, shape (n, p)
Count matrix.
- X : array, shape (n, d)
Covariate matrix (intercept column should be included).
- A : array-like or DataFrame, shape (n, a)
Binary treatment indicator matrix; control cells have all-zero rows.
- r : int
Number of latent factors.
- W_A : array or None, shape (n, d_A)
Propensity-score covariate matrix. If None, X is used.
- batch_size : int
Perturbations per batch (default 10). Ignored when n_batches is set. Batches are sized evenly with numpy.array_split() so the last batch is never drastically smaller than the others.
- n_batches : int or None
Total number of batches. When set, overrides batch_size, and perturbations are split as evenly as possible across exactly n_batches batches (e.g. n_batches=2 on a 29-perturbation dataset gives two batches of 15 and 14).
- max_cells : int or None
Maximum number of perturbed cells per batch (default 2000). None means no cap. Control cells are added on top, so the actual batch size is at most n_ctrl + max_cells. The cap is rarely active because typical Perturb-seq datasets have only a few hundred cells per perturbation.
- n_ctrl : int
Number of control cells in the fixed subsample (default 2000).
- family : str
GLM family (default 'nb').
- offset : bool or array-like
Offset specification passed to fit_gcate_batch().
- warm_start_U : bool
Passed to fit_gcate_batch().
- cache_path : str or None
Path to an HDF5 file used for incremental caching. When set:
On entry, any already-computed batches are loaded from the store and their indices are skipped by fit_gcate_batch().
After each new batch, the result DataFrame is appended to the store under key /batch_{i:04d}.
On exit, all batches (cached + newly computed) are concatenated and returned.
This lets you resume an interrupted run by re-calling the function with the same cache_path; completed batches are not re-run.
- random_state : int
RNG seed.
- verbose : bool
Print per-batch timing.
- gcate_kwargs : dict or None
Extra keyword arguments forwarded to fit_gcate_batch() (and ultimately fit_gcate()). E.g.:
gcate_kwargs=dict(backend='fast', kwargs_es_1=dict(max_iters=10, rel_tol=2e-4), kwargs_es_2=dict(max_iters=10, rel_tol=2e-4))
- lfc_kwargs : dict or None
Extra keyword arguments forwarded to LFC() (e.g. usevar, fdx, thres_min).
- **kwargs
Additional arguments forwarded to both fit_gcate_batch() and LFC(). When a key collides with gcate_kwargs / lfc_kwargs, the stage-specific dict wins; this lets you scope a kwarg to one stage (e.g. gcate_kwargs=dict(backend='fast') paired with a top-level backend='original' targeting LFC).
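The precedence rule for **kwargs versus the stage-specific dicts can be pictured as a plain dictionary merge in which the stage-specific entries are applied last. The helper `merge_stage_kwargs` below is hypothetical and only mirrors the documented behaviour; it is not taken from the library.

```python
def merge_stage_kwargs(shared, stage_specific):
    """Merge shared kwargs with a stage-specific dict; the latter wins on collisions."""
    merged = dict(shared)
    merged.update(stage_specific or {})  # None is treated as an empty dict
    return merged

shared = dict(backend='original')    # top-level **kwargs
gcate_kwargs = dict(backend='fast')  # scoped to the GCATE stage
lfc_kwargs = None                    # nothing scoped to the LFC stage

gcate_args = merge_stage_kwargs(shared, gcate_kwargs)
lfc_args = merge_stage_kwargs(shared, lfc_kwargs)
print(gcate_args, lfc_args)
```

With these inputs the GCATE stage sees backend='fast' while the LFC stage falls back to the top-level backend='original', matching the example in the **kwargs description.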
- Returns:
- df_res : DataFrame
Concatenated result from all batches. Includes a 'batch' column with the 0-based batch index so batches can be identified.
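The even batch sizing described for batch_size and n_batches can be reproduced with numpy.array_split() directly. In the sketch below, the helper `split_perturbations` is hypothetical, and deriving the batch count as ceil(n_perts / batch_size) when n_batches is not given is an assumption rather than documented behaviour.

```python
import numpy as np

def split_perturbations(n_perts, batch_size=10, n_batches=None):
    """Split perturbation indices 0..n_perts-1 into near-equal batches."""
    if n_batches is None:
        # ASSUMPTION: batch count is derived from batch_size by rounding up.
        n_batches = int(np.ceil(n_perts / batch_size))
    return np.array_split(np.arange(n_perts), n_batches)

# n_batches overrides batch_size: 29 perturbations in exactly 2 batches.
sizes_by_count = [len(b) for b in split_perturbations(29, n_batches=2)]
# batch_size=10 on 29 perturbations gives 3 evenly sized batches.
sizes_by_size = [len(b) for b in split_perturbations(29, batch_size=10)]
print(sizes_by_count, sizes_by_size)
```

The first call yields batches of 15 and 14, as in the n_batches example; the second yields 10, 10, and 9 rather than two full batches and a much smaller remainder being handled unevenly.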