Generalized confounder estimation

The GCATE (Generalized Confounder Adjustment for Treatment Effects) functions estimate latent factors that capture unmeasured confounders in the data. Call estimate_r() first to select the number of factors r via the JIC criterion, then call fit_gcate() (or fit_gcate_batch() for large-scale screens) to obtain the latent factor matrix U. Append U to the observed covariate matrix before calling LFC().

causarray.gcate.estimate(Y, X, r, a, lam1, kwargs_glm, kwargs_ls_1, kwargs_es_1, kwargs_ls_2, kwargs_es_2, A_init=None, **kwargs)

Two-stage alternating minimization for the GCATE model (internal).

Parameters:
Y : array-like, shape (n, p)

Response matrix.

X : array-like, shape (n, d+a)

Observed covariate matrix (covariates and treatments concatenated).

r : int

Number of latent variables.

a : int

Number of treatment columns (last a columns of X).

lam1 : float

Regularization parameter for the first stage.

kwargs_glm : dict

GLM configuration (family, dispersion, size factors).

kwargs_ls_1 : dict

Line search arguments for the first stage.

kwargs_ls_2 : dict

Line search arguments for the second stage.

kwargs_es_1 : dict

Early-stopping arguments for the first stage.

kwargs_es_2 : dict

Early-stopping arguments for the second stage.

Returns:
res_1 : dict

First-stage optimisation results.

res_2 : dict

Second-stage results with keys 'X_U' (n, d+a+r) and 'B_Gamma' (p, d+a+r).
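Because 'X_U' and 'B_Gamma' share the same d+a+r column layout, they combine into an (n, p) linear predictor with one matrix product. The sketch below assumes a log-link GLM for counts with an optional per-cell log offset; the link, the offset handling, and the helper name `fitted_mean` are assumptions for illustration, not documented behaviour of estimate().

```python
import numpy as np

def fitted_mean(X_U, B_Gamma, log_offset=None):
    """Hypothetical helper: fitted means from the second-stage output.

    Assumes a log-link GLM, log E[Y] = X_U @ B_Gamma.T (+ a per-cell log offset).
    """
    X_U = np.asarray(X_U, dtype=float)          # (n, d+a+r)
    B_Gamma = np.asarray(B_Gamma, dtype=float)  # (p, d+a+r)
    eta = X_U @ B_Gamma.T                       # linear predictor, shape (n, p)
    if log_offset is not None:
        # Assumed: one log size factor per cell, broadcast across the p genes.
        eta = eta + np.asarray(log_offset, dtype=float).reshape(-1, 1)
    return np.exp(eta)

# Toy shapes only: n=5 cells, p=3 genes, d=1, a=1, r=2 (so d+a+r = 4).
rng = np.random.default_rng(0)
mu = fitted_mean(rng.normal(size=(5, 4)), rng.normal(scale=0.1, size=(3, 4)))
print(mu.shape)  # (5, 3)
```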

causarray.gcate.estimate_r(Y, X, A, r_max, c=1.0, family='nb', disp_glm=None, disp_family='poisson', offset=True, max_cells=None, random_state=0, kwargs_ls_1={}, kwargs_ls_2={}, kwargs_es_1={}, kwargs_es_2={}, **kwargs)

Estimate the number of latent factors for the GCATE model.

Fits GCATE for each candidate number of latent factors up to r_max and selects the value that minimises the JIC (joint information criterion), a penalised-likelihood criterion analogous to BIC.

Parameters:
Y : array-like, shape (n, p)

Response matrix.

X : array-like, shape (n, d)

Observed covariate matrix.

A : array-like, shape (n, a)

Treatment matrix.

r_max : int

Largest number of latent variables to consider.

c : float

The constant factor for the complexity term.

family : str

The family of the GLM. Default is 'nb' (negative binomial).

disp_glm : array-like, shape (1, p) or None

The dispersion parameter for the negative binomial distribution.

max_cells : int or None

Maximum number of cells to use for estimation. When the dataset has more cells than this, a stratified subsample is drawn automatically: up to max_cells // 4 slots are reserved for treated cells (all treated cells if there are fewer than max_cells // 4), so that perturbation-induced latent variation remains visible to the JIC; the remaining budget goes to control cells, and any slots the controls cannot fill spill back to treated cells. This is especially useful for large batch-fitting workflows where n is in the tens of thousands. Because the confounding structure captured by r is concentrated in the baseline transcriptome, a control-priority subsample with a treated-cell floor is both faster and statistically well motivated. None (default) uses all cells.
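The allocation rule for max_cells can be sketched as follows. This is a hypothetical re-implementation written from the description above, not causarray's code: the function name `stratified_subsample`, the boolean `is_treated` vector (for example `A.any(axis=1)`), and uniform sampling within each stratum are all assumptions.

```python
import numpy as np

def stratified_subsample(is_treated, max_cells, random_state=0):
    """Hypothetical sketch of the ctrl-priority subsample with a treated-cell floor."""
    is_treated = np.asarray(is_treated, dtype=bool)
    n = is_treated.size
    if max_cells is None or n <= max_cells:
        return np.arange(n)  # small enough: use every cell
    rng = np.random.default_rng(random_state)
    pert_idx = np.flatnonzero(is_treated)
    ctrl_idx = np.flatnonzero(~is_treated)

    # Floor of max_cells // 4 slots for treated cells (all of them if fewer).
    n_pert = min(len(pert_idx), max_cells // 4)
    # Remaining budget goes to controls, limited by how many controls exist.
    n_ctrl = min(len(ctrl_idx), max_cells - n_pert)
    # Slots the controls could not fill spill back to treated cells.
    n_pert = min(len(pert_idx), max_cells - n_ctrl)

    chosen = np.concatenate([
        rng.choice(ctrl_idx, size=n_ctrl, replace=False),
        rng.choice(pert_idx, size=n_pert, replace=False),
    ])
    return np.sort(chosen)

# 100 treated + 900 control cells, budget of 200: 50 treated, 150 controls.
is_treated = np.array([True] * 100 + [False] * 900)
idx = stratified_subsample(is_treated, max_cells=200)
print(len(idx), int(is_treated[idx].sum()))  # 200 50
```

Recomputing n_pert after the control draw is what implements the "spill back" rule: it can only raise the treated count above the floor, never lower it.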

random_state : int

RNG seed for subsampling (only used when max_cells is set).

kwargs_ls_1 : dict

Keyword arguments for the line search solver in the first stage.

kwargs_ls_2 : dict

Keyword arguments for the line search solver in the second stage.

kwargs_es_1 : dict

Keyword arguments for the early stopper in the first stage.

kwargs_es_2 : dict

Keyword arguments for the early stopper in the second stage.

Returns:
df_r : DataFrame

DataFrame with columns r, deviance, nu, JIC, sorted by r. The optimal r minimises the JIC column.
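Selecting r from the returned table is a one-line argmin over the JIC column. The sketch below uses a hypothetical helper `select_r` and a toy DataFrame whose JIC values are placeholders for illustration, not real estimate_r() output (the deviance and nu columns are omitted because the selection does not use them).

```python
import pandas as pd

def select_r(df_r: pd.DataFrame) -> int:
    """Return the number of factors whose row has the smallest JIC."""
    best_row = df_r.loc[df_r["JIC"].idxmin()]
    return int(best_row["r"])

# Placeholder values only; real tables also carry 'deviance' and 'nu'.
df_r = pd.DataFrame({"r": [1, 2, 3, 4], "JIC": [10.5, 8.2, 8.9, 9.7]})
print(select_r(df_r))  # 2
```

Since the table is sorted by r and idxmin returns the first minimum, ties resolve toward the smaller number of factors.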

causarray.gcate.fit_gcate(Y, X, A, r, family='nb', disp_glm=None, disp_family=None, offset=True, kwargs_ls_1={}, kwargs_ls_2={}, kwargs_es_1={}, kwargs_es_2={}, c1=None, backend: str = 'auto', A_init=None, **kwargs)

Fit the GCATE model to estimate unmeasured confounders.

Runs two-stage alternating minimization to jointly estimate the latent factors U (n × r) and the gene-level coefficients B. The estimated latent factors should then be appended to the covariate matrix before calling LFC().

Parameters:
Y : array-like, shape (n, p)

Count matrix of outcomes.

X : array-like, shape (n, d)

Observed covariate matrix (intercept should be included).

A : array-like, shape (n, a)

Binary treatment indicator matrix.

r : int

Number of unmeasured confounders (latent factors) to estimate. Use estimate_r() to select this value via the JIC criterion.

family : str

GLM family: 'nb' (default, negative binomial) or 'poisson'.

disp_glm : array-like, shape (p,) or None

Dispersion parameters for the NB family. Estimated automatically when None and family='nb'.

disp_family : str or None

Family used for internal dispersion estimation (default 'poisson').

offset : bool or array-like

Log-scale offset. True computes size factors automatically; False or None disables the offset.

kwargs_ls_1 : dict

Keyword arguments for the line search solver in the first stage.

kwargs_ls_2 : dict

Keyword arguments for the line search solver in the second stage.

kwargs_es_1 : dict

Keyword arguments for the early stopper in the first stage.

kwargs_es_2 : dict

Keyword arguments for the early stopper in the second stage.

c1 : float or None

Regularization constant for the first stage. When None (the default), 0.05 is used.

backend : str

GLM backend: "auto" (default), "fast" (force crispyx), or "original" (force statsmodels).

A_init : array-like, shape (n, d + a + r) or None

Optional warm-start matrix [X | A | U] for the first stage. When provided, the SVD-based initialisation is skipped.

**kwargs

Additional keyword arguments forwarded to the GLM fitting functions.

Returns:
res_1 : dict

Results of the first optimization stage.

res_2 : dict

Results of the second optimization stage. Key entries:

'X_U' : array, shape (n, d + a + r)

Augmented covariate matrix [X | A | U].

'B_Gamma' : array, shape (p, d + a + r)

Fitted gene-level coefficient matrix.

To obtain the latent factors, take U = res_2['X_U'][:, d+a:] (the last r columns) and append U to the covariate matrix when constructing W for LFC().
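A minimal sketch of that slicing step, assuming W is simply the observed covariates X with U appended as extra columns. The helper `build_lfc_covariates` is hypothetical, and the toy X_U stands in for real fit_gcate() output.

```python
import numpy as np

def build_lfc_covariates(X, X_U, d, a):
    """Hypothetical helper: return W = [X | U] and U from res_2['X_U']."""
    X = np.asarray(X, dtype=float)
    X_U = np.asarray(X_U, dtype=float)
    if X.shape != (X_U.shape[0], d):
        raise ValueError("X must have shape (n, d) with the same n as X_U")
    U = X_U[:, d + a:]          # columns after [X | A] are the r latent factors
    W = np.hstack([X, U])       # covariates for LFC(): observed plus latent
    return W, U

# Toy example: n=4, d=2 (intercept + one covariate), a=1, r=3.
X = np.column_stack([np.ones(4), np.arange(4.0)])
X_U = np.arange(24, dtype=float).reshape(4, 6)
W, U = build_lfc_covariates(X, X_U, d=2, a=1)
print(W.shape, U.shape)  # (4, 5) (4, 3)
```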

causarray.gcate.fit_gcate_batch(Y, X, A, r, batch_size=10, n_batches=None, max_cells=2000, n_ctrl=2000, family='nb', disp_glm=None, disp_family=None, offset=True, warm_start_U=False, skip_batches=None, random_state=0, verbose=False, **kwargs)

Fit GCATE independently on batches of perturbations.

Partitions the a perturbations into chunks of batch_size and, for each chunk, fits GCATE on the fixed ctrl subsample plus the treated cells of that chunk (subsampled to at most max_cells). Dispersion is pre-estimated once on the ctrl cell pool, so all batches share the same nuisance parameters.

Parameters:
Y : array-like, DataFrame, or scipy.sparse, shape (n, p)

Count matrix. Sparse inputs are densified to float64 once at the start of the function; a ResourceWarning is emitted when the dense materialisation would exceed ~4 GB.
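The densify-once behaviour can be sketched as below. The helper `densify_counts` and the exact threshold (4e9 bytes for "~4 GB") are assumptions; the memory estimate is simply n × p × 8 bytes for a float64 array.

```python
import warnings

import numpy as np
import scipy.sparse as sp

def densify_counts(Y, limit_bytes=4e9):
    """Hypothetical sketch: densify Y to float64 once, warning on large copies."""
    if sp.issparse(Y):
        n, p = Y.shape
        est_bytes = n * p * np.dtype(np.float64).itemsize
        if est_bytes > limit_bytes:
            warnings.warn(
                f"densifying Y needs about {est_bytes / 1e9:.1f} GB",
                ResourceWarning,
                stacklevel=2,
            )
        return Y.toarray().astype(np.float64, copy=False)
    # Dense arrays and DataFrames are converted without a size check.
    return np.asarray(Y, dtype=np.float64)

Y_dense = densify_counts(sp.csr_matrix(np.eye(3)))
print(Y_dense.dtype, Y_dense.shape)  # float64 (3, 3)
```

Python ignores ResourceWarning by default, so enable it (for example with `python -W always::ResourceWarning`) to see the message.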

X : array, shape (n, d)

Covariate matrix (intercept should be included).

A : array-like or DataFrame, shape (n, a)

Binary treatment indicator matrix; control cells have all-zero rows. Each cell is assumed to carry a single perturbation: the batch loop treats any cell whose only active perturbation falls outside the current chunk as a control within that batch, which silently contaminates the within-batch null in combinatorial designs. When multi-perturbation rows are detected, a RuntimeWarning is emitted at function entry; pass n_batches=1 or pre-filter multi-pert rows to suppress it.
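Pre-filtering multi-pert rows amounts to counting the nonzero entries in each row of A. The helper `split_multi_pert` is hypothetical; after calling it, subset Y, X, and A with the same `keep` mask before passing them to fit_gcate_batch().

```python
import numpy as np

def split_multi_pert(A):
    """Hypothetical helper: mask of cells with at most one active perturbation."""
    A = np.asarray(A)
    n_active = (A != 0).sum(axis=1)   # perturbations per cell
    keep = n_active <= 1              # controls (0) and single-pert cells (1)
    return keep, np.flatnonzero(~keep)

A = np.array([[0, 0, 0],    # control
              [1, 0, 0],    # single perturbation
              [1, 1, 0],    # combinatorial: dropped
              [0, 0, 1]])   # single perturbation
keep, multi_rows = split_multi_pert(A)
print(keep.tolist(), multi_rows.tolist())  # [True, True, False, True] [2]
```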

r : int

Number of latent factors.

batch_size : int

Perturbations per batch (default 10). Ignored when n_batches is set. Batches are sized evenly using numpy.array_split(), so the last batch is never more than one perturbation smaller than the others.

n_batches : int or None

Total number of batches. When set, overrides batch_size and the perturbations are split as evenly as possible across exactly n_batches batches. Useful when you want to control the number of batches rather than the per-batch perturbation count (e.g. n_batches=2 splits a 29-pert dataset into two batches of 15/14).
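The batch sizes follow directly from numpy.array_split(). The sketch assumes that, when only batch_size is given, the number of batches is ceil(a / batch_size); that rule and the helper `batch_sizes` are assumptions, while the n_batches=2 case reproduces the 15/14 example above.

```python
import math

import numpy as np

def batch_sizes(n_perts, batch_size=10, n_batches=None):
    """Hypothetical helper: perturbations per batch under even splitting."""
    if n_batches is None:
        # Assumed rule: enough batches that none exceeds batch_size.
        n_batches = math.ceil(n_perts / batch_size)
    chunks = np.array_split(np.arange(n_perts), n_batches)
    return [len(chunk) for chunk in chunks]

print(batch_sizes(29))               # [10, 10, 9]
print(batch_sizes(29, n_batches=2))  # [15, 14]
```

Under this assumption a batch can hold fewer than batch_size perturbations even when it is not the last one: 25 perturbations with batch_size=10 give [9, 8, 8] rather than [10, 10, 5].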

max_cells : int or None

Maximum number of pert cells per batch (default 2 000). None means no cap: all pert cells are used. The ctrl pool is added on top, so the actual batch size is at most n_ctrl + max_cells. 2 000 is a safe default because most Perturb-seq datasets have only a few hundred cells per perturbation, so the cap is rarely active.

n_ctrl : int

Number of ctrl cells in the fixed subsample shared across batches (default 2 000).

family : str

GLM family, 'nb' or 'poisson' (default 'nb').

disp_glm : array or None

Dispersion parameter (p,). If None and family='nb', estimated once on the ctrl cell subsample before the batch loop.

disp_family : str or None

Passed to fit_gcate (used only when disp_glm is None and the internal estimation path is taken inside each batch).

offset : bool or array-like

Offset specification passed to fit_gcate.

warm_start_U : bool

If True, initialises U rows for ctrl cells in batch i+1 from the latent factors estimated in batch i. Requires the same ctrl indices in every batch (guaranteed by design).

skip_batches : set of int or None

Batch indices (0-based) to skip entirely. Used by gcate_lfc_batch() to avoid re-running GCATE for batches whose LFC results are already cached on disk. Skipped batches still appear in the returned list with 'skipped': True and res_1/res_2 = None.

random_state : int

Base RNG seed; each batch uses random_state + batch_i for pert subsampling to avoid drawing the same cells repeatedly.

verbose : bool

Print per-batch timing and progress.

**kwargs

Forwarded to fit_gcate() (e.g. backend, kwargs_es_1).

Returns:
batch_results : list of dict

One dict per batch with keys:

'batch_i'

Batch index (0-based).

'pert_names'

List of perturbation names in this batch.

'ctrl_idx'

Global indices of ctrl cells (same for all batches).

'pert_idx'

Global indices of pert cells used in this batch.

'cell_idx'

Sorted union of ctrl and pert indices.

'res_1', 'res_2'

GCATE optimisation results (dicts from fit_gcate()).

'disp_glm'

Shared dispersion array (p,) used for this batch.

't_batch'

Wall-clock seconds for this batch.

'skipped'

True when the batch was in skip_batches; absent otherwise.
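A sketch of consuming the returned list: skip entries flagged 'skipped', then take the last r columns of each 'X_U' as that batch's latent factors, which works regardless of how many treatment columns the batch had. The helper `collect_latent_factors` is hypothetical, and it assumes the rows of 'X_U' follow the order of 'cell_idx'; the mocked results only mimic the documented keys.

```python
import numpy as np

def collect_latent_factors(batch_results, r):
    """Hypothetical helper: map batch_i -> (cell_idx, U) for fitted batches."""
    out = {}
    for res in batch_results:
        if res.get("skipped", False):
            continue  # res_1 / res_2 are None for skipped batches
        X_U = np.asarray(res["res_2"]["X_U"])
        # U is the last r columns of [X | A | U]; guard r == 0 (X_U[:, -0:] is everything).
        U = X_U[:, -r:] if r > 0 else X_U[:, :0]
        # Assumed: row k of X_U corresponds to global cell cell_idx[k].
        out[res["batch_i"]] = (np.asarray(res["cell_idx"]), U)
    return out

# Mocked output with the documented keys (values are toy arrays).
batch_results = [
    {"batch_i": 0, "cell_idx": np.array([0, 1, 2]), "res_1": {},
     "res_2": {"X_U": np.arange(15.0).reshape(3, 5)}},
    {"batch_i": 1, "cell_idx": np.array([0, 1, 3]), "res_1": None,
     "res_2": None, "skipped": True},
]
factors = collect_latent_factors(batch_results, r=2)
print(sorted(factors))  # [0]
```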