import os
import torch
import numpy as np
import pandas as pd
import scanpy as sc
import anndata
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import seaborn as sns
from . import nicheformer as nf
from . import workflow as wl
from . import prob_nmfae as pnf
from . import utils
import importlib
import scipy
[docs]
def calculate_wb_ez(adata, model_path, batch_key=None, neighbor_num=100, latent_dim=20, cellrep_key='X_pca'):
    r"""
    Calculate the embeddings required for the score function and add them to the AnnData object.

    The score function is defined as:

    .. math::

        s_{\theta}(e_i, z_j) = w_e(e_i)^\top w_z(z_j) + b_z(z_j)

    where :math:`w_e` and :math:`w_z` are neural networks mapping microenvironmental and cell-state
    embeddings to a shared hidden dimension, and :math:`b_z` provides a cell-state-dependent bias.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix.
    model_path : str
        Path to the trained model checkpoint.
    batch_key : str, optional
        Key in `adata.obs` indicating batch information. Default is None.
    neighbor_num : int, optional
        Number of neighbors used during training. Default is 100.
    latent_dim : int, optional
        Latent dimension of the model. Default is 20.
    cellrep_key : str, optional
        Key in `adata.obsm` containing the cell state representations, or
        ``'X'`` to use ``adata.X`` (densified if needed). Default is 'X_pca'.

    Returns
    -------
    anndata.AnnData
        The input AnnData object (modified in place and returned) updated with:

        - `obsm['nf_cellrep']`: The cell state representation selected by ``cellrep_key``.
        - `obsm['w_e']`: Projected microenvironmental embeddings.
        - `obsm['w_z']`: Projected cell state embeddings.
        - `obsm['b_z']`: Cell state bias terms.
        - `obsm['e']`: Microenvironmental embeddings (only computed if not already present).
    """
    model_id_dict = {
        'ldim': latent_dim,
        'nn': neighbor_num,
        'crkey': cellrep_key,
        'bkey': batch_key,
    }
    model = wl.loading_pre_trained_model(model_path, adata, model_id_dict)

    # Expose the chosen cell-state representation under a fixed key so the
    # workflow helpers below can read it regardless of where it came from.
    if cellrep_key == 'X':
        X = adata.X
        if isinstance(X, np.ndarray):
            cellrep = X
        elif hasattr(X, 'toarray'):
            # scipy.sparse matrices/arrays expose ``toarray``.
            cellrep = X.toarray()
        else:
            # Other array-likes without ``toarray``; previously an AttributeError.
            cellrep = np.asarray(X)
        adata.obsm['nf_cellrep'] = cellrep
    else:
        adata.obsm['nf_cellrep'] = adata.obsm[cellrep_key]

    # Reuse existing microenvironment embeddings if present; otherwise infer them.
    if 'e' not in adata.obsm:
        ds = wl.adata2ds(adata, neighbor_num=neighbor_num, batch_key=batch_key)
        e, _mu, _sigma = wl.output_dist_params(ds, model)
        # ``Tensor.numpy()`` raises for tensors that require grad or live on a
        # GPU, so detach and move to CPU before converting.
        if isinstance(e, torch.Tensor):
            e = e.detach().cpu().numpy()
        adata.obsm['e'] = np.asarray(e)
    wl.add_wb_ez(adata, model, cell_rep_key='nf_cellrep')
    return adata
[docs]
def calculate_niche_density_ratio(adata, ref_num=1000, stratify_key='leiden_e', min_ratio=0.01, ref_adata=None):
    r"""
    Score every cell against a stratified panel of reference niches.

    A panel of reference niches :math:`j` is drawn by stratified sampling over
    ``adata.obs[stratify_key]``. For each cell :math:`i`, the log density ratio is

    .. math::

        \log r_{ij}
        = \log p(e_j \mid z_i) - \log p(e_j)
        = (w_z(z_i)^\top w_e(e_j) + b_z(z_i))
          - \log \sum_{k \in \mathrm{ref}}
            \exp(w_z(z_k)^\top w_e(e_j) + b_z(z_k)).

    Each row of the resulting matrix is then softmax-normalized over the
    reference niches. As a result, ``adata.obsm['dist_e'][i]`` is a probability
    distribution that up-weights niches whose environment is more likely given
    cell ``i``'s state than under the marginal.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix whose ``obsm`` holds ``w_e``, ``w_z`` and ``b_z``
        (as produced by :func:`calculate_wb_ez`).
    ref_num : int, optional
        How many reference niches to sample. Default is 1000.
    stratify_key : str, optional
        Column of ``adata.obs`` used to stratify the reference sampling.
        Default is 'leiden_e'.
    min_ratio : float, optional
        Clusters whose frequency falls below this fraction are excluded from
        the stratified sampling. Default is 0.01.
    ref_adata : anndata.AnnData, optional
        External reference data. When ``None``, a subset of ``adata`` serves
        as the reference.

    Returns
    -------
    anndata.AnnData
        ``adata`` with two additions:

        - ``obsm['dist_e']``: softmax-normalized density ratios of shape
          ``(n_cells, ref_num)``.
        - ``uns['dist_e']['ref_obs']``: obs names of the sampled reference
          niches.

        The ``dist_e`` key name is kept for backward compatibility with
        existing h5ad artifacts.
    """
    density_kwargs = dict(
        ref_niche_num=ref_num,
        stratify_key=stratify_key,
        min_ratio=min_ratio,
        ref_adata=ref_adata,
    )
    # The helper's return value is not used; it is expected to update ``adata`` in place.
    wl.calculate_niche_density_ratio(adata, **density_kwargs)
    return adata
[docs]
def calculate_niche_cluster_membership(adata, cluster_key='leiden_e'):
    r"""
    Summarize per-cell density ratios as a soft membership over niche clusters.

    The columns of ``adata.obsm['dist_e']`` (one per reference niche) are
    averaged within each label of ``adata.obs[cluster_key]``, typically the
    ``leiden_e`` niche clusters. The result, ``adata.obsm['dist_e_agg']``, has
    shape ``(n_cells, n_niche_clusters)``. Entry ``[i, c]`` is the mean of the
    (softmax-normalized) density ratio :math:`p(e \mid z_i) / p(e)` over the
    reference cells in cluster ``c``. It can be read as a soft assignment of
    cell ``i`` to niche cluster ``c``.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix holding ``obsm['dist_e']`` (see
        :func:`calculate_niche_density_ratio`). If it is missing, it is
        computed with default settings.
    cluster_key : str, optional
        Column of ``adata.obs`` containing the niche cluster labels.
        Default is 'leiden_e'.

    Returns
    -------
    anndata.AnnData
        ``adata`` with ``obsm['dist_e_agg']`` added. This is the per-cell
        niche-cluster membership, with one column per niche cluster label.
        The ``dist_e_agg`` key name is kept for backward compatibility with
        existing h5ad artifacts used by figure scripts.
    """
    # NOTE(review): both the aggregation and the fallback computation of a
    # missing ``obsm['dist_e']`` are delegated to the workflow helper; confirm
    # the fallback there if it matters to the caller.
    wl.calculate_niche_cluster_membership(
        adata,
        group_key=cluster_key,
    )
    return adata
[docs]
def estimate_population_density(adata, group, cluster_key, max_cell_num=1000):
    r"""
    Estimate how densely a given cell population occupies each microenvironment.

    Integrating :math:`P(z \mid e)` over the cell states of the cells labelled
    ``group`` in ``adata.obs[cluster_key]`` yields the density (existence
    probability) of that population in microenvironment :math:`e`.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix.
    group : str
        Label of the population (e.g. a specific cell type) whose density is
        estimated.
    cluster_key : str
        Column of ``adata.obs`` holding the cell type / cluster labels.
    max_cell_num : int, optional
        Upper bound on the number of cells sampled from ``group`` for the
        estimate. Default is 1000.

    Returns
    -------
    anndata.AnnData
        ``adata`` with a new ``obs`` column (e.g. ``{group}_density``) holding
        the estimated density of the population at each cell's
        microenvironment.
    """
    # All work is delegated to the workflow helper; its return value is unused.
    wl.estimate_population_density(
        adata,
        group,
        cluster_key,
        max_cell_num,
    )
    return adata
[docs]
def analyze_density_correlation(adata, density_col, gene_list=None, file_path=None):
    """
    Analyze the correlation between estimated cell population density and gene expression.

    This analysis helps identify gene expression signatures associated with colocalization
    with specific cell populations. For example, it can identify genes upregulated in tumor
    cells when they colocalize with endothelial cells.

    The Pearson correlation is computed between ``adata.obs[density_col]`` and each
    gene's expression. Dense and sparse expression matrices take separate code
    paths, but both follow the same conventions:

    - Cells whose density is missing (including non-numeric values, which are
      coerced to NaN) are excluded.
    - Genes with zero variance, or a constant density, get a NaN correlation,
      because the coefficient is undefined.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix containing expression data and the density column.
    density_col : str
        Name of the column in `adata.obs` containing the estimated density values.
    gene_list : list of str, optional
        List of genes to include in the correlation analysis. If None, uses all
        genes in `adata.var_names`.
    file_path : str, optional
        Path for a horizontal bar plot of the 10 most negatively and 10 most
        positively correlated genes. If None (or empty), no plot is made.

    Returns
    -------
    pandas.Series
        Correlation coefficient for each gene, indexed by gene name in
        ``gene_list`` order.
    """
    if gene_list is None:
        gene_list = adata.var_names

    # Coerce to a float64 ndarray; non-numeric entries and missing values become NaN.
    density = pd.Series(
        pd.to_numeric(adata.obs[density_col].values, errors='coerce')
    ).to_numpy(dtype=np.float64, na_value=np.nan)

    X = adata[:, gene_list].X
    if scipy.sparse.issparse(X):
        corrs = pd.Series(_sparse_pearson(X, density), index=gene_list)
    else:
        df_exp = pd.DataFrame(np.asarray(X), index=adata.obs_names, columns=gene_list)
        # ``corrwith`` drops NaN pairs and yields NaN for constant columns.
        corrs = df_exp.corrwith(pd.Series(density, index=adata.obs_names))

    if file_path:
        _plot_top_bottom_correlations(corrs, density_col, file_path)
    return corrs


def _sparse_pearson(X, density):
    """Pearson correlation of every column of sparse ``X`` with the 1-D float array ``density``."""
    valid = ~np.isnan(density)
    if not valid.all():
        # Match the dense path (pairwise-complete): drop cells with missing density.
        X = X.tocsr()[np.flatnonzero(valid)]
        density = density[valid]
    n = X.shape[0]
    if n == 0:
        return np.full(X.shape[1], np.nan)

    # Work in float64: E[x^2] - E[x]^2 suffers from cancellation in float32.
    X = X.astype(np.float64, copy=False)
    means = np.asarray(X.mean(axis=0)).ravel()
    sq_means = np.asarray(X.power(2).mean(axis=0)).ravel()
    # Rounding can push the variance marginally below zero; clip before sqrt.
    stds = np.sqrt(np.clip(sq_means - means ** 2, 0.0, None))
    d_mean = density.mean()
    d_std = density.std()

    covs = np.asarray(X.T @ density).ravel() / n - means * d_mean
    denom = stds * d_std
    with np.errstate(divide='ignore', invalid='ignore'):
        # Undefined (zero-variance) correlations are NaN, as in ``corrwith``.
        return np.where(denom > 0, covs / denom, np.nan)


def _plot_top_bottom_correlations(corrs, density_col, file_path):
    """Save a horizontal bar plot of the 10 lowest and 10 highest correlations to ``file_path``."""
    plot_data = pd.concat([corrs.nsmallest(10), corrs.nlargest(10)])
    # With fewer than 20 genes the two selections overlap; keep each gene once.
    plot_data = plot_data[~plot_data.index.duplicated()].sort_values()
    fig = plt.figure(figsize=(10, 8))
    try:
        sns.barplot(x=plot_data.values, y=plot_data.index, palette="vlag", orient="h")
        plt.title(f'Correlation with {density_col} (Top/Bottom 10)')
        plt.xlabel('Correlation coefficient')
        plt.tight_layout()
        plt.savefig(file_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
[docs]
def analyze_niche_membership(adata, n_clusters=15, file_path=None):
    """
    Group cells by their niche-cluster membership profiles and plot the result.

    The feature space is ``adata.obsm['dist_e_agg']``, the per-cell soft
    membership over niche clusters produced by
    :func:`calculate_niche_cluster_membership`. Cells are partitioned into
    ``n_clusters`` groups by Ward hierarchical clustering. A clustermap of the
    membership matrix is then drawn, annotated with row colors.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix that already carries ``obsm['dist_e_agg']``.
    n_clusters : int, optional
        Number of cell groups to produce. Default is 15.
    file_path : str, optional
        Destination for the clustermap image. When ``None`` the plot is not
        saved.

    Returns
    -------
    anndata.AnnData
        ``adata`` with ``obs['niche_composition_cluster']`` (the cell group
        labels) added. The ``niche_composition_cluster`` key name is kept for
        backward compatibility with existing h5ad artifacts.
    """
    # Step 1: cluster cells on their membership vectors (expected to populate
    # ``obs['niche_composition_cluster']``).
    wl.cluster_cells_by_niche_membership(adata, n_clusters=n_clusters)
    # Step 2: draw the clustermap; this call is made even when ``file_path`` is None.
    wl.plot_niche_membership_clustermap(adata, file_path=file_path)
    return adata