Generalized confounder estimation
The GCATE (Generalized Confounder Adjustment for Treatment Effects) functions
estimate latent factors that capture unmeasured confounders in the data.
Call estimate_r() first to select the number of factors r by minimising the joint information criterion (JIC),
then call fit_gcate() (or fit_gcate_batch() for
large-scale screens) to obtain the latent factor matrix U. Append U
to the observed covariate matrix before calling LFC().
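The workflow above can be sketched as follows. This is a minimal sketch, not a helper shipped with causarray: the name gcate_pipeline is hypothetical, and it takes the GCATE module as an argument so it can be exercised without the library installed. It assumes, from the documented return values, that estimate_r() returns the df_r DataFrame and that fit_gcate() returns the pair (res_1, res_2). The final LFC() call is left out because its signature is not documented here.

```python
import numpy as np


def gcate_pipeline(Y, X, A, r_max, gcate):
    """Hypothetical helper: select r, fit GCATE, and build the LFC covariates.

    `gcate` is expected to expose `estimate_r` and `fit_gcate` with the
    signatures documented for `causarray.gcate`.
    """
    X = np.asarray(X)
    A = np.asarray(A)
    d, a = X.shape[1], A.shape[1]

    # Step 1: pick the number of latent factors that minimises the JIC.
    df_r = gcate.estimate_r(Y, X, A, r_max)
    r = int(df_r.loc[df_r["JIC"].idxmin(), "r"])

    # Step 2: fit GCATE with that r; res_2["X_U"] is laid out as [X | A | U].
    res_1, res_2 = gcate.fit_gcate(Y, X, A, r)
    U = res_2["X_U"][:, d + a:]

    # Step 3: append U to the observed covariates; W is what LFC() would receive.
    W = np.hstack([X, U])
    return r, W
```

With the library installed, the intended call would look like `from causarray import gcate` followed by `r, W = gcate_pipeline(Y, X, A, r_max=10, gcate=gcate)`.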
- causarray.gcate.estimate(Y, X, r, a, lam1, kwargs_glm, kwargs_ls_1, kwargs_es_1, kwargs_ls_2, kwargs_es_2, A_init=None, **kwargs)
Two-stage alternating minimization for the GCATE model (internal).
- Parameters:
- Y : array-like, shape (n, p)
Response matrix.
- X : array-like, shape (n, d+a)
Observed covariate matrix (covariates and treatments concatenated).
- r : int
Number of latent variables.
- a : int
Number of treatment columns (the last a columns of X).
- lam1 : float
Regularization parameter for the first stage.
- kwargs_glm : dict
GLM configuration (family, dispersion, size factors).
- kwargs_ls_1 : dict
Line search arguments for the first stage.
- kwargs_ls_2 : dict
Line search arguments for the second stage.
- kwargs_es_1 : dict
Early-stopping arguments for the first stage.
- kwargs_es_2 : dict
Early-stopping arguments for the second stage.
- Returns:
- res_1 : dict
First-stage optimisation results.
- res_2 : dict
Second-stage results with keys 'X_U', shape (n, d+a+r), and 'B_Gamma', shape (p, d+a+r).
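The internal estimate() expects covariates and treatments already concatenated, with the treatments in the last a columns. The sketch below shows that layout with random placeholder arrays (not model output). It also assumes, as an illustration consistent with the documented shapes but not stated above, that the fitted linear predictor is X_U @ B_Gamma.T before any offset is added, giving one row per cell and one column per gene.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, d, a, r = 50, 8, 3, 2, 2

# Intercept plus d - 1 covariates, and a binary treatment columns.
covariates = np.column_stack([np.ones(n), rng.normal(size=(n, d - 1))])
A = rng.integers(0, 2, size=(n, a)).astype(float)

# Layout expected by estimate(): covariates first, treatments in the last `a` columns.
X = np.hstack([covariates, A])
assert np.array_equal(X[:, -a:], A)

# Placeholders with the documented shapes of res_2['X_U'] and res_2['B_Gamma'].
X_U = np.hstack([X, rng.normal(size=(n, r))])  # (n, d + a + r)
B_Gamma = rng.normal(size=(p, d + a + r))       # (p, d + a + r)

# Assumed linear predictor: cells by genes.
Theta = X_U @ B_Gamma.T
print(Theta.shape)  # (50, 8)
```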
- causarray.gcate.estimate_r(Y, X, A, r_max, c=1.0, family='nb', disp_glm=None, disp_family='poisson', offset=True, max_cells=None, random_state=0, kwargs_ls_1={}, kwargs_ls_2={}, kwargs_es_1={}, kwargs_es_2={}, **kwargs)
Estimate the number of latent factors for the GCATE model.
Fits GCATE for each candidate number of factors up to r_max and selects the value that minimises the JIC (joint information criterion), a penalised-likelihood criterion analogous to BIC.
- Parameters:
- Y : array-like, shape (n, p)
Response matrix.
- X : array-like, shape (n, d)
Observed covariate matrix.
- A : array-like, shape (n, a)
Treatment matrix.
- r_max : int
Maximum number of latent variables to consider.
- c : float
Constant factor multiplying the complexity term of the JIC.
- family : str
The family of the GLM, 'nb' or 'poisson'. Default is 'nb'.
- disp_glm : array-like, shape (1, p) or None
The dispersion parameter for the negative binomial distribution.
- max_cells : int or None
Maximum number of cells to use for estimation. When the dataset exceeds this size, a stratified subsample is drawn automatically: a floor of max_cells // 4 slots is reserved for treated cells (or all of them if len(pert_idx) < max_cells // 4) so that perturbation-induced latent variation remains visible to the JIC, and the remaining budget goes to controls (with leftovers spilling back to treated cells if controls are themselves scarce). This is especially useful for large batch-fitting workflows where n is in the tens of thousands: confounding structure (captured by r) is concentrated in the baseline transcriptome, so a ctrl-priority subsample with a treated-cell floor is both faster and statistically principled. None (default) uses all cells.
- random_state : int
RNG seed for subsampling (only used when max_cells is set).
- kwargs_ls_1 : dict
Keyword arguments for the line search solver in the first stage.
- kwargs_ls_2 : dict
Keyword arguments for the line search solver in the second stage.
- kwargs_es_1 : dict
Keyword arguments for the early stopper in the first stage.
- kwargs_es_2 : dict
Keyword arguments for the early stopper in the second stage.
- Returns:
- df_r : DataFrame
DataFrame with columns r, deviance, nu and JIC, sorted by r. The optimal r is the value that minimises the JIC column.
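The max_cells allocation rule described above can be illustrated with a small stand-alone function. This is a sketch of the documented rule (a treated floor of max_cells // 4, controls next, leftovers spilling back to treated cells), not causarray's own code. The function name stratified_subsample and its arguments are hypothetical, and uniform sampling without replacement within each group is an assumption.

```python
import numpy as np


def stratified_subsample(ctrl_idx, pert_idx, max_cells, random_state=0):
    """Hypothetical sketch of a ctrl-priority subsample with a treated-cell floor."""
    ctrl_idx = np.asarray(ctrl_idx)
    pert_idx = np.asarray(pert_idx)
    if max_cells is None or len(ctrl_idx) + len(pert_idx) <= max_cells:
        # Small enough (or no cap): use every cell.
        return np.sort(np.concatenate([ctrl_idx, pert_idx]))

    # Reserve a floor of max_cells // 4 slots for treated cells (or all of them).
    n_pert = min(len(pert_idx), max_cells // 4)
    # Give the remaining budget to controls, as far as controls are available.
    n_ctrl = min(len(ctrl_idx), max_cells - n_pert)
    # Spill any unused budget back to treated cells.
    n_pert = min(len(pert_idx), max_cells - n_ctrl)

    rng = np.random.default_rng(random_state)
    chosen = np.concatenate([
        rng.choice(ctrl_idx, size=n_ctrl, replace=False),
        rng.choice(pert_idx, size=n_pert, replace=False),
    ])
    return np.sort(chosen)
```

With 1,000 controls, 500 treated cells and max_cells=100, this draws 75 controls and 25 treated cells; with only 30 controls, it keeps all 30 and fills the remaining 70 slots with treated cells.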
- causarray.gcate.fit_gcate(Y, X, A, r, family='nb', disp_glm=None, disp_family=None, offset=True, kwargs_ls_1={}, kwargs_ls_2={}, kwargs_es_1={}, kwargs_es_2={}, c1=None, backend: str = 'auto', A_init=None, **kwargs)
Fit the GCATE model to estimate unmeasured confounders.
Runs two-stage alternating minimization to jointly estimate the latent factors U (n × r) and the gene-level coefficients B. The estimated latent factors should then be appended to the covariate matrix before calling LFC().
- Parameters:
- Y : array-like, shape (n, p)
Count matrix of outcomes.
- X : array-like, shape (n, d)
Observed covariate matrix (intercept should be included).
- A : array-like, shape (n, a)
Binary treatment indicator matrix.
- r : int
Number of unmeasured confounders (latent factors) to estimate. Use estimate_r() to select this value via the JIC.
- family : str
GLM family: 'nb' (default, negative binomial) or 'poisson'.
- disp_glm : array-like, shape (p,) or None
Dispersion parameters for the NB family. Estimated automatically when None and family='nb'.
- disp_family : str or None
Family used for internal dispersion estimation (default 'poisson').
- offset : bool or array-like
Log-scale offset. True computes size factors automatically; False or None disables the offset.
- kwargs_ls_1 : dict
Keyword arguments for the line search solver in the first stage.
- kwargs_ls_2 : dict
Keyword arguments for the line search solver in the second stage.
- kwargs_es_1 : dict
Keyword arguments for the early stopper in the first stage.
- kwargs_es_2 : dict
Keyword arguments for the early stopper in the second stage.
- c1 : float
Regularization constant for the first stage. Default is 0.05.
- backend : str
GLM backend: "auto" (default), "fast" (force crispyx), or "original" (force statsmodels).
- A_init : array-like, shape (n, d + a + r) or None
Optional warm-start matrix [X | A | U] for the first stage. When provided, the SVD-based initialisation is skipped.
- **kwargs
Additional keyword arguments forwarded to the GLM fitting functions.
- Returns:
- res_1 : dict
Results of the first optimization stage.
- res_2 : dict
Results of the second optimization stage. Key entries:
- 'X_U' : array, shape (n, d + a + r). Augmented covariate matrix [X | A | U].
- 'B_Gamma' : array, shape (p, d + a + r). Fitted gene-level coefficient matrix.
Take U = res_2['X_U'][:, d+a:] as the latent factor block when constructing W for LFC().
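Because res_2['X_U'] is documented as [X | A | U], its column blocks can be recovered by position, and the same layout is what A_init expects for a warm start. The sketch below uses placeholder arrays. The helper name split_blocks is hypothetical, and reading the columns of B_Gamma with the same d/a/r split (covariate coefficients, treatment coefficients, latent loadings) is an assumption based on the matching shapes.

```python
import numpy as np


def split_blocks(res_2, d, a):
    """Hypothetical helper: split GCATE second-stage output by column block."""
    X_U, B_Gamma = res_2["X_U"], res_2["B_Gamma"]
    return {
        "X": X_U[:, :d],           # observed covariates
        "A": X_U[:, d:d + a],      # treatment indicators
        "U": X_U[:, d + a:],       # estimated latent factors
        # Assumption: B_Gamma columns follow the same order as X_U.
        "B_cov": B_Gamma[:, :d],
        "B_treat": B_Gamma[:, d:d + a],
        "Gamma": B_Gamma[:, d + a:],
    }


rng = np.random.default_rng(1)
n, p, d, a, r = 40, 6, 2, 1, 3
X = np.column_stack([np.ones(n), rng.normal(size=n)])
A = rng.integers(0, 2, size=(n, a)).astype(float)
res_2 = {"X_U": np.hstack([X, A, rng.normal(size=(n, r))]),
         "B_Gamma": rng.normal(size=(p, d + a + r))}

blocks = split_blocks(res_2, d, a)
# Warm start [X | A | U] for a later fit_gcate(..., A_init=A_init) with the same r.
A_init = np.hstack([X, A, blocks["U"]])
print(blocks["U"].shape, A_init.shape)  # (40, 3) (40, 6)
```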
- causarray.gcate.fit_gcate_batch(Y, X, A, r, batch_size=10, n_batches=None, max_cells=2000, n_ctrl=2000, family='nb', disp_glm=None, disp_family=None, offset=True, warm_start_U=False, skip_batches=None, random_state=0, verbose=False, **kwargs)
Fit GCATE independently on batches of perturbations.
Partitions the a perturbations into chunks of batch_size (or into n_batches chunks) and for each chunk selects the fixed ctrl subsample plus (a subset of) the treated cells, with the treated cells capped at max_cells per batch. Dispersion is pre-estimated once on the ctrl cell pool so all batches share the same nuisance parameters.
- Parameters:
- Y : array-like, DataFrame, or scipy.sparse, shape (n, p)
Count matrix. Sparse inputs are densified to float64 once at the start of the function; a ResourceWarning is emitted when the dense materialisation would exceed ~4 GB.
- X : array, shape (n, d)
Covariate matrix (intercept should be included).
- A : array-like or DataFrame, shape (n, a)
Binary treatment indicator matrix; control cells have all-zero rows. Single-perturbation-per-cell is assumed: the batch loop treats any cell whose only active perturbation falls outside the current chunk as a control within that batch, which silently contaminates the within-batch null for combinatorial designs. When such rows are detected, a RuntimeWarning is emitted at function entry; pass n_batches=1 or pre-filter multi-pert rows to suppress it.
- r : int
Number of latent factors.
- batch_size : int
Perturbations per batch (default 10). Ignored when n_batches is set. Batches are sized evenly using numpy.array_split(), so the last batch is never more than one perturbation smaller than the others.
- n_batches : int or None
Total number of batches. When set, overrides batch_size and the perturbations are split as evenly as possible across exactly n_batches batches. Useful when you want to control the number of batches rather than the per-batch perturbation count (e.g. n_batches=2 splits a 29-pert dataset into two batches of 15 and 14).
- max_cells : int or None
Maximum pert cells per batch (default 2 000). None means no cap: all pert cells are used. The ctrl pool is added on top, so the actual batch size is at most n_ctrl + max_cells. 2 000 is a safe default because most Perturb-seq datasets have only a few hundred cells per perturbation, so the cap is rarely active.
- n_ctrl : int
Number of ctrl cells in the fixed subsample shared across batches (default 2 000).
- family : str
GLM family, 'nb' or 'poisson' (default 'nb').
- disp_glm : array or None
Dispersion parameter, shape (p,). If None and family='nb', estimated once on the ctrl cell subsample before the batch loop.
- disp_family : str or None
Passed to fit_gcate (used only when disp_glm is None and the internal estimation path is taken inside each batch).
- offset : bool or array-like
Offset specification passed to fit_gcate.
- warm_start_U : bool
If True, initialises the U rows for ctrl cells in batch i+1 from the latent factors estimated in batch i. Requires the same ctrl indices in every batch (guaranteed by design).
- skip_batches : set of int or None
Batch indices (0-based) to skip entirely. Used by gcate_lfc_batch() to avoid re-running GCATE for batches whose LFC results are already cached on disk. Skipped batches still appear in the returned list with 'skipped': True and res_1 = res_2 = None.
- random_state : int
Base RNG seed; each batch uses random_state + batch_i for pert subsampling to avoid drawing the same cells repeatedly.
- verbose : bool
Print per-batch timing and progress.
- **kwargs
Forwarded to fit_gcate() (e.g. backend, kwargs_es_1).
- Returns:
- batch_results : list of dict
One dict per batch with keys:
- 'batch_i': batch index (0-based).
- 'pert_names': list of perturbation names in this batch.
- 'ctrl_idx': global indices of ctrl cells (same for all batches).
- 'pert_idx': global indices of pert cells used in this batch.
- 'cell_idx': sorted union of ctrl and pert indices.
- 'res_1', 'res_2': GCATE optimisation results (dicts from fit_gcate()).
- 'disp_glm': shared dispersion array, shape (p,), used for this batch.
- 't_batch': wall-clock seconds for this batch.
- 'skipped': True when the batch was in skip_batches; absent otherwise.
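The batch partitioning rules for fit_gcate_batch() can be sketched with numpy.array_split(). This illustrates the documented behaviour (batch_size versus n_batches, a per-batch seed of random_state + batch_i, skipped batches kept in the list) and is not the library's implementation. The helper plan_batches, its "seed" key, and the choice of ceil(a / batch_size) sections when only batch_size is given are assumptions.

```python
import math

import numpy as np


def plan_batches(pert_names, batch_size=10, n_batches=None, skip_batches=None,
                 random_state=0):
    """Hypothetical sketch: split perturbations into evenly sized batches."""
    pert_names = list(pert_names)
    if n_batches is None:
        # Enough sections that no batch exceeds batch_size perturbations.
        n_batches = math.ceil(len(pert_names) / batch_size)
    skip_batches = set(skip_batches or ())

    plan = []
    chunks = np.array_split(np.arange(len(pert_names)), n_batches)
    for batch_i, chunk in enumerate(chunks):
        entry = {"batch_i": batch_i,
                 "pert_names": [pert_names[j] for j in chunk],
                 "seed": random_state + batch_i}  # distinct seed per batch
        if batch_i in skip_batches:
            entry["skipped"] = True  # kept in the list, but not refitted
        plan.append(entry)
    return plan
```

For 29 perturbations, the default batch_size=10 yields batches of 10, 10 and 9, while n_batches=2 yields 15 and 14, matching the example in the n_batches description.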
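Two of the input checks described for fit_gcate_batch() can be reproduced before calling it: finding cells with more than one active perturbation (the condition behind the RuntimeWarning) and estimating the size of the dense float64 copy of Y (the condition behind the ResourceWarning above ~4 GB). The helper preflight is a hypothetical sketch; the documentation does not say whether "~4 GB" means 4 × 10⁹ or 4 × 2³⁰ bytes, so the default of 4 × 10⁹ is an assumption.

```python
import numpy as np
import scipy.sparse as sp


def preflight(Y, A, limit_bytes=4e9):
    """Hypothetical pre-checks mirroring the documented warnings."""
    A = np.asarray(A)
    # Rows with two or more active perturbations break the one-pert-per-cell assumption.
    multi_pert_rows = np.flatnonzero((A != 0).sum(axis=1) > 1)
    # A dense float64 copy costs 8 bytes per entry, however sparse Y is.
    n, p = Y.shape
    dense_bytes = 8 * n * p
    return {"multi_pert_rows": multi_pert_rows,
            "dense_bytes": dense_bytes,
            "exceeds_limit": dense_bytes > limit_bytes,
            "is_sparse": sp.issparse(Y)}
```

For example, a 100,000-cell by 8,000-gene matrix needs 8 × 10⁵ × 8 × 10³ = 6.4 × 10⁹ bytes once densified, which is above the threshold.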