Doubly-robust semiparametric inference

LFC() estimates per-gene log-fold changes using a doubly-robust augmented inverse probability weighting (AIPW) estimator. It requires the augmented covariate matrix W = [X | U], where U contains the latent factors estimated by fit_gcate(), and returns a DataFrame with effect estimates, standard errors, and BH-adjusted p-values.

For screens with hundreds of perturbations, use gcate_lfc_batch(), which runs GCATE and LFC in batches to keep peak memory bounded.

causarray.DR_learner.LFC(Y, W, A, W_A=None, family='nb', offset=False, Y_hat=None, pi_hat=None, cross_est=False, mask=None, usevar='unequal', thres_min=0.01, thres_diff=0.01, eps_var=0.0001, fdx=False, fdx_alpha=0.05, fdx_c=0.1, verbose=False, backend: str = 'auto', **kwargs)

Estimate log-fold changes of treatment effects (LFCs) using AIPW.

Fits a doubly-robust AIPW estimator for the log-ratio of counterfactual means E[Y(1)] / E[Y(0)]. Call this after fit_gcate() to incorporate estimated latent factors into the covariate matrix W.
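
As a rough illustration of the estimand, the sketch below computes a textbook AIPW estimate of log(E[Y(1)] / E[Y(0)]) for one gene and one binary treatment. It is not causarray's implementation: the function name aipw_log_fold_change and its inputs (mu0 and mu1 for outcome-model predictions, pi for propensity scores) are hypothetical, and the nuisance estimates are assumed to be already available.

```python
import math


def aipw_log_fold_change(y, a, mu0, mu1, pi):
    """Textbook AIPW estimate of log(E[Y(1)] / E[Y(0)]) for a single gene.

    y   : observed counts
    a   : 0/1 treatment indicators
    mu0 : predicted means under control (hypothetical outcome-model output)
    mu1 : predicted means under treatment
    pi  : estimated P(A = 1 | W) (hypothetical propensity-model output)
    All arguments are equal-length lists.
    """
    n = len(y)
    # Pseudo-outcomes: model prediction plus an inverse-probability-weighted residual.
    eta1 = [m1 + ai * (yi - m1) / p
            for yi, ai, m1, p in zip(y, a, mu1, pi)]
    eta0 = [m0 + (1 - ai) * (yi - m0) / (1 - p)
            for yi, ai, m0, p in zip(y, a, mu0, pi)]
    psi1 = sum(eta1) / n  # estimate of E[Y(1)]
    psi0 = sum(eta0) / n  # estimate of E[Y(0)]
    return math.log(psi1) - math.log(psi0)


# Toy data: six cells, three treated.
y = [2, 0, 5, 7, 1, 6]
a = [0, 0, 1, 1, 0, 1]
mu0 = [1.0] * 6
mu1 = [6.0] * 6
pi = [0.5] * 6

tau = aipw_log_fold_change(y, a, mu0, mu1, pi)
```

The estimate stays consistent if either the outcome model or the propensity model is correctly specified, which is the sense in which the estimator is doubly robust.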

Parameters:
Y : array, shape (n, p)

Count matrix of outcomes.

W : array, shape (n, d)

Covariate matrix, typically [X | U] where U are the latent factors from GCATE.

A : array, shape (n, a)

Binary treatment indicator matrix.

W_A : array or None, shape (n, d_A)

Covariate matrix for the propensity model. If None, W is used.

family : str

GLM family for the outcome model: 'nb' (default) or 'poisson'.

offset : bool or array-like

Log-scale offset for the outcome model. True computes size factors automatically; False or None disables the offset.

Y_hat : array or None, shape (n, p, a, 2)

Pre-computed counterfactual predictions. When provided, cross-fitting is skipped.

pi_hat : array or None, shape (n, a)

Pre-computed propensity scores. When provided, propensity fitting is skipped.

cross_est : bool

Whether to use cross-estimation for nuisance parameters.

mask : array or None, shape (n, a)

Boolean mask indicating which cells are used for the final estimand computation. Does not affect cross-fitting or propensity estimation.

usevar : str

Variance estimator for the AIPW pseudo-outcomes:

  • 'unequal' (default, v0.0.6+): Welch variance s₀²/n₀ + s₁²/n₁ with Welch-Satterthwaite degrees of freedom; p-values use the t-distribution.

  • 'pooled': pooled-variance estimator (s² + eps_var) / n.

Changed in version 0.0.6: Default changed from 'pooled' to 'unequal'. The 'unequal' formula was also corrected from (s₀²/n₀ + s₁²/n₁)/2 to the standard Welch form; because the old formula halved the variance, t-statistics under 'unequal' are now smaller by a factor of ≈ √2 than in v0.0.5. Pass usevar='pooled' to recover the pre-v0.0.6 default behaviour.
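
The Welch variance and Welch-Satterthwaite degrees of freedom described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's code: welch_se_and_df is a hypothetical helper, sample variances use the n - 1 denominator, and eps_var is omitted.

```python
import math
from statistics import variance  # sample variance (n - 1 denominator)


def welch_se_and_df(x0, x1):
    """Welch standard error and Welch-Satterthwaite degrees of freedom."""
    n0, n1 = len(x0), len(x1)
    v0 = variance(x0) / n0  # s0^2 / n0
    v1 = variance(x1) / n1  # s1^2 / n1
    se2 = v0 + v1
    df = se2 ** 2 / (v0 ** 2 / (n0 - 1) + v1 ** 2 / (n1 - 1))
    return math.sqrt(se2), df


# Toy pseudo-outcomes for the control and treated arms.
x0 = [1.0, 2.0, 3.0, 4.0]
x1 = [2.0, 4.0, 6.0, 8.0, 10.0]

se_new, df = welch_se_and_df(x0, x1)

# Pre-v0.0.6 'unequal' formula, which divided the summed variance by 2.
se_old = math.sqrt((variance(x0) / len(x0) + variance(x1) / len(x1)) / 2)

# The corrected standard error is larger by sqrt(2), so t-statistics shrink by sqrt(2).
ratio = se_new / se_old
```

The degrees of freedom always fall between min(n₀, n₁) - 1 and n₀ + n₁ - 2, so with small or unbalanced arms the t-distribution p-values are more conservative than a normal approximation.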

thres_min : float

Genes whose maximum counterfactual mean is below this threshold are excluded (reported as tau=0, padj=NaN).

thres_diff : float

Genes whose counterfactual means differ by less than this value are excluded.

eps_var : float

Small constant added to per-arm variances to prevent division by zero.

fdx : bool

Whether to apply FDX control (P(FDP > fdx_c) < fdx_alpha).

fdx_alpha : float

Significance level for FDX control.

fdx_c : float

FDP threshold for FDX control.

verbose : bool

Print progress information.

backend : str

GLM backend: "auto" (default), "fast" (force crispyx), or "original" (force statsmodels).

**kwargs

Additional arguments forwarded to the GLM fitting functions.

Returns:
df_res : DataFrame

Test results with columns gene_names, tau, std, stat, rej, pvalue, padj, pvalue_emp_null_adj, padj_emp_null_adj (and trt when A has multiple columns).
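
Assuming padj holds Benjamini-Hochberg adjusted p-values, as the summary at the top of this page indicates, the adjustment can be sketched in plain Python as below. The helper bh_adjust is hypothetical and only shows the standard step-up procedure; how the library treats excluded genes (padj=NaN) is not covered.

```python
def bh_adjust(pvalues):
    """Benjamini-Hochberg adjusted p-values, returned in the input order."""
    m = len(pvalues)
    order = sorted(range(m), key=lambda i: pvalues[i])  # indices by ascending p
    padj = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down so adjusted values stay monotone.
    for rank in range(m, 0, -1):
        i = order[rank - 1]
        running_min = min(running_min, pvalues[i] * m / rank)
        padj[i] = running_min
    return padj


pvals = [0.01, 0.04, 0.03, 0.20]
padj = bh_adjust(pvals)
```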

causarray.DR_learner.LFC_batch(*args, **kwargs)

Deprecated alias for gcate_lfc_batch().

Deprecated: Use gcate_lfc_batch() instead. LFC_batch will be removed in a future release.

causarray.DR_learner.VIM(eta_est, X, id_covs, **kwargs)

Estimate variable importance measures (VIM) for heterogeneous treatment effects.

Decomposes treatment effect variance into components explained by each covariate using conditional average treatment effect (CATE) regression.

Parameters:
eta_est : array, shape (n, p)

Influence function values from LFC() or compute_causal_estimand().

X : array, shape (n, d)

Covariate matrix.

id_covs : int or array-like of int

Column indices of X to compute VIM for. An integer k is treated as range(k).

Returns:
estimation : dict

Dictionary with keys:

'CATE', 'CATE_lower', 'CATE_upper' : array, shape (n_covs, n, p)

Conditional average treatment effect and pointwise confidence band.

'VTE' : array, shape (p,)

Total variance of the treatment effect (marginal).

'CVTE' : array, shape (n_covs, p)

Conditional variance of the treatment effect given each covariate.

'VIM_mean' : array, shape (n_covs, p)

VIM point estimate CVTE / VTE - 1 for each covariate and gene.

'VIM_sd' : array, shape (n_covs, p)

Standard deviation of the VIM estimate.

causarray.DR_learner.compute_causal_estimand(estimand, Y, W, A, W_A=None, family='nb', offset=False, Y_hat=None, pi_hat=None, mask=None, fdx=False, fdx_B=1000, fdx_alpha=0.05, fdx_c=0.1, verbose=False, random_state=0, backend: str = 'auto', **kwargs)

Estimate causal treatment effects using AIPW with a user-supplied estimand.

Parameters:
estimand : callable

Function that maps influence function values (etas, A) to (eta_est, tau_est, var_est[, df_eff]). See LFC() for an example implementation.

Y : array, shape (n, p)

Count matrix of outcomes.

W : array, shape (n, d)

Covariate matrix (including latent factors from GCATE).

A : array, shape (n, a)

Binary treatment indicator matrix.

W_A : array or None, shape (n, d_A)

Covariate matrix for the propensity model. If None, W is used.

family : str

GLM family for the outcome model: 'nb' (default) or 'poisson'.

offset : bool or array-like

Log-scale offset for the outcome model. True computes size factors automatically; False or None disables the offset.

Y_hat : array or None, shape (n, p, a, 2)

Pre-computed counterfactual predictions. When provided, cross-fitting is skipped.

pi_hat : array or None, shape (n, a)

Pre-computed propensity scores. When provided, propensity fitting is skipped.

mask : array or None, shape (n, a)

Boolean mask indicating which cells are used for the final estimand computation. Does not affect cross-fitting or propensity estimation.

fdx : bool

Whether to apply FDX control (P(FDP > fdx_c) < fdx_alpha).

fdx_B : int

Number of bootstrap samples for FDX control.

fdx_alpha : float

Significance level for FDX control.

fdx_c : float

FDP threshold for FDX control.

backend : str

GLM backend: "auto" (default), "fast" (force crispyx), or "original" (force statsmodels).

verbose : bool

Print progress information.

**kwargs

Additional arguments forwarded to the GLM fitting functions.

Returns:
df_res : DataFrame

Test results with columns gene_names, tau, std, stat, rej, pvalue, padj, pvalue_emp_null_adj, padj_emp_null_adj (and trt when A has multiple columns).

causarray.DR_learner.gcate_lfc_batch(Y, X, A, r, W_A=None, batch_size=10, n_batches=None, max_cells=2000, n_ctrl=2000, family='nb', offset=True, warm_start_U=False, cache_path=None, random_state=0, verbose=False, gcate_kwargs=None, lfc_kwargs=None, **kwargs)

Batch-wise GCATE + doubly-robust LFC estimation.

Partitions perturbations into chunks of batch_size, runs fit_gcate_batch() to estimate per-batch latent confounders, then calls LFC() on each batch independently. All large intermediate arrays (res_1, res_2, Y_hat, pi_hat) are freed immediately after each batch so that peak memory is bounded by one batch’s worth of data regardless of the total number of perturbations.

Results can optionally be cached to an HDF5 file (cache_path) so that interrupted runs can be resumed without re-processing completed batches.

Parameters:
Y : array-like or DataFrame, shape (n, p)

Count matrix.

X : array, shape (n, d)

Covariate matrix (intercept column should be included).

A : array-like or DataFrame, shape (n, a)

Binary treatment indicator matrix; control cells have all-zero rows.
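
A treatment matrix with this layout can be built from per-cell perturbation labels roughly as follows. The helper build_treatment_matrix and the control label "ctrl" are hypothetical; adapt them to however your metadata marks control cells.

```python
def build_treatment_matrix(labels, control="ctrl"):
    """One-hot encode perturbation labels; control cells get all-zero rows.

    Returns (A, perts) where A is a list of n rows with one column per
    perturbation and perts gives the column order.
    """
    perts = sorted({lab for lab in labels if lab != control})
    col = {p: j for j, p in enumerate(perts)}
    A = [[0] * len(perts) for _ in labels]
    for i, lab in enumerate(labels):
        if lab != control:
            A[i][col[lab]] = 1
    return A, perts


labels = ["ctrl", "geneA", "geneB", "ctrl", "geneA"]
A, perts = build_treatment_matrix(labels)
```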

r : int

Number of latent factors.

W_A : array or None, shape (n, d_A)

Propensity-score covariate matrix. If None, X is used.

batch_size : int

Perturbations per batch (default 10). Ignored when n_batches is set. Batches are sized evenly with numpy.array_split() so the last batch is never drastically smaller than the others.

n_batches : int or None

Total number of batches. When set, overrides batch_size, and perturbations are split as evenly as possible across exactly n_batches batches (e.g. n_batches=2 on a dataset with 29 perturbations gives two batches of 15 and 14).
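
The even-splitting rule can be sketched without NumPy as shown below. split_evenly mirrors the sizing of numpy.array_split() (the first len(items) % n_batches batches receive one extra item). plan_batches is a hypothetical helper, and its rule for deriving the batch count from batch_size, ceil(n_perts / batch_size), is an assumption rather than something this page states.

```python
import math


def split_evenly(items, n_batches):
    """Split items into n_batches parts with the same sizes as numpy.array_split."""
    base, extra = divmod(len(items), n_batches)
    batches, start = [], 0
    for b in range(n_batches):
        size = base + (1 if b < extra else 0)  # first `extra` batches get one more
        batches.append(items[start:start + size])
        start += size
    return batches


def plan_batches(n_perts, batch_size=10, n_batches=None):
    """Return the batch sizes for n_perts perturbations."""
    if n_batches is None:
        # Assumption: the batch count is derived from batch_size like this.
        n_batches = math.ceil(n_perts / batch_size)
    return [len(b) for b in split_evenly(list(range(n_perts)), n_batches)]


sizes_two = plan_batches(29, n_batches=2)  # the example from the text
sizes_default = plan_batches(21)           # batch_size=10, 21 perturbations
```

Under that assumption, 21 perturbations with batch_size=10 yield three batches of 7 rather than 10, 10 and 1, which is the behaviour the batch_size description alludes to.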

max_cells : int or None

Maximum number of perturbed cells per batch (default 2000). None means no cap. Control cells are added on top, so the actual batch size is at most n_ctrl + max_cells. The cap is rarely active because typical Perturb-seq datasets have only a few hundred cells per perturbation.

n_ctrl : int

Number of control cells in the fixed subsample (default 2000).

family : str

GLM family (default 'nb').

offset : bool or array-like

Offset specification passed to fit_gcate_batch().

warm_start_U : bool

Passed to fit_gcate_batch().

cache_path : str or None

Path to an HDF5 file used for incremental caching. When set:

  • On entry, any already-computed batches are loaded from the store and their indices are skipped by fit_gcate_batch().

  • After each new batch, the result DataFrame is appended to the store under key /batch_{i:04d}.

  • On exit, all batches (cached + newly computed) are concatenated and returned.

This lets you resume an interrupted run by calling the function again with the same cache_path; completed batches are not re-run.
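
The resume logic in the list above can be sketched with a plain dict standing in for the HDF5 store. This only illustrates the skip-and-concatenate pattern with the /batch_{i:04d} key scheme; run_batches_with_cache and compute are hypothetical names, and the real function stores DataFrames rather than lists.

```python
def run_batches_with_cache(batches, compute, store):
    """Process batches in order, skipping those already present in `store`.

    `store` is any dict-like mapping; here it stands in for the HDF5 file.
    """
    for i, batch in enumerate(batches):
        key = f"/batch_{i:04d}"
        if key in store:
            continue  # finished in an earlier run, do not recompute
        store[key] = compute(batch)
    # Concatenate cached and newly computed results in batch order.
    results = []
    for i in range(len(batches)):
        results.extend(store[f"/batch_{i:04d}"])
    return results


calls = []


def compute(batch):
    """Toy per-batch computation that records which batches were run."""
    calls.append(tuple(batch))
    return [x * 2 for x in batch]


# Pretend the first batch completed before the run was interrupted.
store = {"/batch_0000": [2, 4]}
out = run_batches_with_cache([[1, 2], [3, 4], [5]], compute, store)
```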

random_state : int

RNG seed.

verbose : bool

Print per-batch timing.

gcate_kwargs : dict or None

Extra keyword arguments forwarded to fit_gcate_batch() (and ultimately fit_gcate()). E.g.:

gcate_kwargs=dict(backend='fast',
                  kwargs_es_1=dict(max_iters=10, rel_tol=2e-4),
                  kwargs_es_2=dict(max_iters=10, rel_tol=2e-4))

lfc_kwargs : dict or None

Extra keyword arguments forwarded to LFC() (e.g. usevar, fdx, thres_min).

**kwargs

Additional arguments forwarded to both fit_gcate_batch() and LFC(). When a key collides with gcate_kwargs / lfc_kwargs, the stage-specific dict wins. This lets you scope a kwarg to one stage (e.g. gcate_kwargs=dict(backend='fast') paired with a top-level backend='original' that then applies only to LFC).
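
The precedence rule amounts to a dict merge in which the stage-specific dict is applied last. The sketch below shows that rule only; resolve_stage_kwargs is a hypothetical helper, not part of causarray.

```python
def resolve_stage_kwargs(shared, gcate_kwargs=None, lfc_kwargs=None):
    """Merge shared kwargs with per-stage kwargs; stage-specific keys win."""
    gcate = {**shared, **(gcate_kwargs or {})}
    lfc = {**shared, **(lfc_kwargs or {})}
    return gcate, lfc


gcate, lfc = resolve_stage_kwargs(
    shared=dict(backend="original"),    # top-level **kwargs
    gcate_kwargs=dict(backend="fast"),  # overrides backend for the GCATE stage only
)
```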

Returns:
df_res : DataFrame

Concatenated result from all batches. Includes a 'batch' column with the 0-based batch index so batches can be identified.