# Source code for velot.pp

"""
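Preprocessing utilities that prepare an AnnData object for the VelOT
pipeline. The step functions follow the scanpy convention: they modify
adata in place and return it for optional chaining. ``prepare`` instead
works on a copy by default (``copy=True``), so keep its return value.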

Typical usage::

    import velot

    # All-in-one (returns a preprocessed copy by default)
    adata = velot.pp.prepare(adata, n_pcs=30, root_cluster="Root")

    # Or step by step (each call modifies adata in place)
    velot.pp.normalize(adata)
    velot.pp.select_genes(adata, n_hvg=2000)
    velot.pp.scale(adata)
    velot.pp.pca(adata, n_pcs=30)
    velot.pp.neighbors(adata)
    velot.pp.umap(adata)
    velot.pp.pseudotime(adata, root_cluster="Root")
"""

from __future__ import annotations

import warnings
from typing import Optional

import numpy as np
import scanpy as sc
from anndata import AnnData


# ======================================================================
# Individual preprocessing steps
# ======================================================================


def normalize(adata: AnnData, target_sum: float = 1e4) -> AnnData:
    """
    Library-size normalize every cell, then apply ``log1p``.

    Parameters
    ----------
    adata
        Annotated data matrix whose ``adata.X`` holds raw counts.
    target_sum
        Total count that each cell is rescaled to before the log transform.

    Returns
    -------
    The same ``adata`` object, modified in place.
    """
    preprocessing = sc.pp
    # Depth normalization first, so log1p sees comparable per-cell totals.
    preprocessing.normalize_total(adata, target_sum=target_sum)
    preprocessing.log1p(adata)
    return adata
def select_genes(
    adata: AnnData,
    n_hvg: int = 2000,
    flavor: str = "seurat",
) -> AnnData:
    """
    Restrict ``adata`` to its ``n_hvg`` most highly variable genes.

    When the object already has ``n_hvg`` genes or fewer it is returned
    unchanged; no HVG statistics are computed in that case.

    Parameters
    ----------
    adata
        Annotated data matrix, expected to be log-normalized.
    n_hvg
        How many highly variable genes to retain.
    flavor
        HVG selection method, forwarded to
        ``scanpy.pp.highly_variable_genes``.

    Returns
    -------
    The same ``adata`` object, subset to the selected genes in place.
    """
    if adata.n_vars > n_hvg:
        sc.pp.highly_variable_genes(
            adata,
            n_top_genes=n_hvg,
            flavor=flavor,
            subset=False,
        )
        # Boolean gene mask written by scanpy into adata.var.
        keep_mask = adata.var["highly_variable"].values
        adata._inplace_subset_var(keep_mask)
    return adata
def scale(adata: AnnData) -> AnnData:
    """
    Standardize every gene to zero mean and unit variance.

    Putting all genes on the same scale lets the subsequent PCA reflect
    their correlation structure, instead of being driven by whichever genes
    happen to be expressed most strongly.

    Returns
    -------
    The same ``adata`` object, with ``adata.X`` scaled in place.
    """
    # Default scanpy settings: zero-centering on, no clipping (max_value).
    sc.pp.scale(adata)
    return adata
def pca(adata: AnnData, n_pcs: int = 30) -> AnnData:
    """
    Compute PCA embedding.

    The PCA coordinates (adata.obsm['X_pca']) are the space where velocity
    will be computed. This is NOT just for visualization.

    The ARPACK solver only supports ``n_pcs < min(n_obs, n_vars)``. When more
    components are requested, ``n_pcs`` is clamped to the largest valid value
    and a warning is emitted, instead of failing inside the solver.

    Parameters
    ----------
    adata
        Annotated data matrix, typically scaled first (see
        ``velot.pp.scale``).
    n_pcs
        Requested number of principal components.

    Returns
    -------
    adata with adata.obsm['X_pca'] populated.

    Raises
    ------
    ValueError
        If ``adata`` has fewer than 2 cells or 2 genes, which leaves no valid
        ARPACK decomposition at all.
    """
    # ARPACK needs strictly fewer components than the smaller matrix side.
    max_pcs = min(adata.n_obs, adata.n_vars) - 1
    if max_pcs < 1:
        raise ValueError(
            f"PCA needs at least 2 cells and 2 genes; got "
            f"{adata.n_obs} cells x {adata.n_vars} genes."
        )
    if n_pcs > max_pcs:
        warnings.warn(
            f"n_pcs={n_pcs} exceeds the ARPACK limit "
            f"min(n_obs, n_vars) - 1 = {max_pcs}; using n_pcs={max_pcs}.",
            stacklevel=2,
        )
        n_pcs = max_pcs
    sc.tl.pca(adata, n_comps=n_pcs, svd_solver="arpack")
    return adata
def neighbors(
    adata: AnnData,
    n_pcs: int = 30,
    n_neighbors: int = 30,
) -> AnnData:
    """
    Build the k-nearest-neighbor graph.

    Later VelOT stages depend on this graph for:

    - locality penalties in the OT cost matrix
    - KNN-consistency smoothing of velocities
    - diffusion pseudotime (DPT)

    Parameters
    ----------
    adata
        Annotated data matrix, normally with PCA already computed.
    n_pcs
        Number of principal components scanpy should use.
    n_neighbors
        Size of each cell's neighborhood.

    Returns
    -------
    The same ``adata`` object, with ``adata.obsp['connectivities']`` and
    ``adata.obsp['distances']`` filled in.
    """
    graph_options = {"n_pcs": n_pcs, "n_neighbors": n_neighbors}
    sc.pp.neighbors(adata, **graph_options)
    return adata
def umap(adata: AnnData) -> AnnData:
    """
    Compute a 2-D UMAP embedding, used purely for plotting.

    VelOT never estimates velocity in UMAP space; these coordinates exist
    only for visualization.

    Returns
    -------
    The same ``adata`` object, with ``adata.obsm['X_umap']`` filled in.
    """
    tools = sc.tl
    tools.umap(adata)
    return adata
def pseudotime(
    adata: AnnData,
    *,
    key: Optional[str] = None,
    root_cluster: Optional[str] = None,
    root_cell: Optional[int] = None,
    cluster_key: str = "clusters",
) -> AnnData:
    """
    Attach a pseudotime ordering, rescaled to [0, 1], to ``adata.obs``.

    The source of the ordering is chosen by the arguments, in priority order:

    1. ``key``: read a precomputed pseudotime from ``adata.obs[key]``.
    2. ``root_cell``: run DPT rooted at that cell index.
    3. ``root_cluster``: run DPT rooted at the first cell of that cluster.

    Diffusion pseudotime (DPT) is derived only from the expression
    geometry. It uses no velocity or spliced/unspliced information, which
    keeps the later velocity estimate independent of it.

    Parameters
    ----------
    adata
        Must already carry a KNN graph (call ``velot.pp.neighbors`` first).
    key
        Name of an ``adata.obs`` column holding precomputed pseudotime.
    root_cluster
        Cluster label whose first cell roots the DPT.
    root_cell
        Cell index that roots the DPT; takes precedence over root_cluster.
    cluster_key
        ``adata.obs`` column containing the cluster labels.

    Returns
    -------
    The same ``adata`` object, with ``adata.obs['pseudotime']`` in [0, 1].
    """
    raw_order = (
        _load_pseudotime(adata, key)
        if key is not None
        else _compute_dpt(
            adata,
            root_cluster=root_cluster,
            root_cell=root_cell,
            cluster_key=cluster_key,
        )
    )
    adata.obs["pseudotime"] = _normalize_01(raw_order)
    return adata
# ======================================================================
# All-in-one convenience function
# ======================================================================
def prepare(
    adata: AnnData,
    *,
    n_pcs: int = 30,
    n_neighbors: int = 30,
    n_hvg: Optional[int] = 2000,
    pseudotime_key: Optional[str] = None,
    root_cluster: Optional[str] = None,
    root_cell: Optional[int] = None,
    cluster_key: str = "clusters",
    do_normalize: bool = True,
    copy: bool = True,
) -> AnnData:
    """
    Full preprocessing in one call.

    Runs: normalize → select_genes → scale → PCA → neighbors → UMAP →
    pseudotime. The pseudotime inputs are validated *before* any
    processing. A missing or misspelled pseudotime source therefore fails
    immediately and, with ``copy=False``, never leaves ``adata``
    half-processed.

    Parameters
    ----------
    adata
        Raw or partially processed AnnData object.
    n_pcs
        Number of principal components.
    n_neighbors
        Number of neighbors for KNN graph.
    n_hvg
        Number of HVGs to select. None to skip.
    pseudotime_key
        Precomputed pseudotime column name. If provided, DPT is skipped.
    root_cluster
        Root cluster for DPT.
    root_cell
        Root cell index for DPT.
    cluster_key
        Column with cluster labels.
    do_normalize
        Whether to normalize + log1p. Set False if already done.
    copy
        Whether to operate on a copy of adata.

    Returns
    -------
    Preprocessed adata. With ``copy=True`` (the default) this is a new object
    and the input is left untouched, so assign the result.

    Raises
    ------
    KeyError
        If ``pseudotime_key`` is not a column of ``adata.obs``.
    ValueError
        If no pseudotime source is given, or if DPT is to be rooted at
        ``root_cluster`` but ``cluster_key`` or that cluster is missing.
    IndexError
        If ``root_cell`` is out of range for the number of cells.

    Example
    -------
    ::

        import velot
        import scvelo as scv

        adata = scv.datasets.pancreas()
        adata = velot.pp.prepare(
            adata, root_cluster="Ductal", cluster_key="clusters"
        )
    """
    # Fail fast: check the pseudotime source before copying or mutating.
    _check_pseudotime_source(
        adata,
        pseudotime_key=pseudotime_key,
        root_cluster=root_cluster,
        root_cell=root_cell,
        cluster_key=cluster_key,
    )

    if copy:
        adata = adata.copy()

    if do_normalize:
        normalize(adata)
    if n_hvg is not None:
        select_genes(adata, n_hvg=n_hvg)
    scale(adata)
    pca(adata, n_pcs=n_pcs)
    # PCA may store fewer components than requested on very small inputs;
    # never ask the KNN step for more components than X_pca actually holds.
    n_pcs_used = min(n_pcs, adata.obsm["X_pca"].shape[1])
    neighbors(adata, n_pcs=n_pcs_used, n_neighbors=n_neighbors)
    umap(adata)
    pseudotime(
        adata,
        key=pseudotime_key,
        root_cluster=root_cluster,
        root_cell=root_cell,
        cluster_key=cluster_key,
    )

    # Store pipeline parameters for reproducibility
    adata.uns["velot_params"] = {
        "n_pcs": n_pcs_used,
        "n_neighbors": n_neighbors,
        "n_hvg": n_hvg,
        "pseudotime_source": pseudotime_key or "dpt",
    }
    return adata


def _check_pseudotime_source(
    adata: AnnData,
    *,
    pseudotime_key: Optional[str],
    root_cluster: Optional[str],
    root_cell: Optional[int],
    cluster_key: str,
) -> None:
    """Validate pseudotime inputs with the same precedence pseudotime() uses."""
    if pseudotime_key is not None:
        if pseudotime_key not in adata.obs:
            raise KeyError(
                f"Pseudotime column '{pseudotime_key}' not found in "
                f"adata.obs. Available columns: {list(adata.obs.columns)}"
            )
        return
    if root_cell is not None:
        n_obs = adata.n_obs
        if not -n_obs <= int(root_cell) < n_obs:
            raise IndexError(
                f"root_cell={int(root_cell)} is out of range for "
                f"{n_obs} cells."
            )
        return
    if root_cluster is None:
        raise ValueError(
            "For pseudotime, provide one of: "
            "pseudotime_key (precomputed), root_cluster, or root_cell."
        )
    if cluster_key not in adata.obs:
        raise ValueError(
            f"cluster_key '{cluster_key}' not found in adata.obs. "
            f"Available: {list(adata.obs.columns)}"
        )
    labels = adata.obs[cluster_key].astype(str)
    if not (labels == str(root_cluster)).any():
        raise ValueError(
            f"Root cluster '{root_cluster}' not found. "
            f"Available: {sorted(labels.unique().tolist())}"
        )
# ====================================================================== # Internal helpers # ====================================================================== def _load_pseudotime(adata: AnnData, key: str) -> np.ndarray: """Load and validate a precomputed pseudotime column.""" if key not in adata.obs: raise KeyError( f"Pseudotime column '{key}' not found in adata.obs. " f"Available columns: {list(adata.obs.columns)}" ) pt = adata.obs[key].values.astype(np.float64) n_nan = np.isnan(pt).sum() if n_nan > 0: warnings.warn( f"Pseudotime column '{key}' contains {n_nan} NaN values. " f"Assigning them maximum pseudotime.", stacklevel=2, ) pt[np.isnan(pt)] = np.nanmax(pt) return pt def _compute_dpt( adata: AnnData, *, root_cluster: Optional[str] = None, root_cell: Optional[int] = None, cluster_key: str = "clusters", ) -> np.ndarray: """ Compute Diffusion Pseudotime. DPT builds a diffusion process on the KNN graph and measures diffusion distance from a root cell. It uses only the expression geometry — no velocity or spliced/unspliced information. """ if root_cell is not None: adata.uns["iroot"] = int(root_cell) elif root_cluster is not None: if cluster_key not in adata.obs: raise ValueError( f"cluster_key '{cluster_key}' not found in adata.obs. " f"Available: {list(adata.obs.columns)}" ) mask = adata.obs[cluster_key].astype(str) == str(root_cluster) if not mask.any(): available = sorted( adata.obs[cluster_key].astype(str).unique().tolist() ) raise ValueError( f"Root cluster '{root_cluster}' not found. " f"Available: {available}" ) adata.uns["iroot"] = int(np.where(mask)[0][0]) else: raise ValueError( "For pseudotime, provide one of: " "key (precomputed), root_cluster, or root_cell." 
) sc.tl.diffmap(adata) sc.tl.dpt(adata) pt = adata.obs["dpt_pseudotime"].values.astype(np.float64) # DPT can produce infinities for disconnected components inf_mask = ~np.isfinite(pt) if inf_mask.any(): warnings.warn( f"DPT produced {inf_mask.sum()} non-finite values " f"(disconnected cells). Assigning maximum pseudotime.", stacklevel=2, ) pt[inf_mask] = np.nanmax(pt[np.isfinite(pt)]) return pt def _normalize_01(pt: np.ndarray) -> np.ndarray: """Normalize an array to [0, 1], handling edge cases.""" pt = pt.astype(np.float64) pt_min = np.nanmin(pt) pt_max = np.nanmax(pt) if pt_max - pt_min < 1e-12: warnings.warn( "Pseudotime has near-zero range. Returning all zeros.", stacklevel=2, ) return np.zeros_like(pt) return (pt - pt_min) / (pt_max - pt_min)