# Source code for velot.tl

"""
Core analysis functions for the VelOT pipeline.

Follows the scanpy/scvelo convention: functions modify adata in place
and return it for optional chaining.

Quick usage::

    import velot
    velot.tl.velocity(adata)   # runs the full pipeline

Step-by-step usage::

    velot.tl.build_windows(adata)
    velot.tl.compute_ot_velocity(adata)
    velot.tl.smooth_velocity(adata)
    velot.tl.project_to_umap(adata)
"""

from __future__ import annotations

import warnings
from typing import Optional

import numpy as np
from anndata import AnnData
from scipy.spatial import cKDTree
from sklearn.cluster import KMeans

try:
    import ot as pot
except ImportError:
    raise ImportError(
        "The POT (Python Optimal Transport) library is required. "
        "Install it with: pip install POT"
    )

try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _HAS_TORCH = True
except ImportError:
    _HAS_TORCH = False
    DEVICE = None


# =====================================================================
# 1. SPATIAL-TEMPORAL WINDOWING
# =====================================================================


def _temporal_windows(
    sorted_indices: np.ndarray,
    window_len: int,
    step: int,
    tail_handling: str,
    tail_threshold: int,
) -> list:
    """Slice pseudotime-sorted cell indices of one cluster into windows.

    Regular windows of ``window_len`` cells are taken every ``step`` cells.
    The leftover cells at the end are then handled according to
    ``tail_handling`` (see ``build_windows``). A tail window is never
    appended when it is identical to the window before it, so no window is
    ever paired with itself.
    """
    n_sorted = len(sorted_indices)
    windows = []
    last_start = 0
    for start in range(0, n_sorted - window_len + 1, step):
        windows.append(sorted_indices[start : start + window_len])
        last_start = start
    if not windows:
        return windows

    # Cells from where the next regular window would have started to the end.
    remaining_count = n_sorted - (last_start + step)
    if remaining_count <= 0:
        return windows

    def _append_if_new(window: np.ndarray) -> None:
        # A duplicate of the previous window would create an OT pair
        # (w, w) that only dilutes the averaged velocity.
        if not np.array_equal(windows[-1], window):
            windows.append(window)

    if tail_handling == "split" and remaining_count >= 2:
        # Spread the tail over two full-sized windows to soften the overlap
        # jump and gain one extra OT step.
        half = remaining_count // 2
        end_a = n_sorted - half
        _append_if_new(sorted_indices[max(0, end_a - window_len) : end_a])
        _append_if_new(sorted_indices[max(0, n_sorted - window_len) : n_sorted])
    elif tail_handling == "drop":
        # Only keep the tail window when enough cells are left.
        if remaining_count >= tail_threshold:
            _append_if_new(sorted_indices[-window_len:])
    else:
        # "force" (also used for "split" with a single leftover cell):
        # always end with a window that reaches the last cell.
        _append_if_new(sorted_indices[-window_len:])
    return windows


def build_windows(
    adata: AnnData,
    basis: str = "X_pca",
    n_clusters: Optional[int] = None,
    window_size: Optional[int] = None,
    overlap_fraction: float = 0.5,
    min_window_size: int = 20,
    spatial_key: Optional[str] = None,
    tail_handling: str = "force",
    tail_threshold: int = 10,
    random_state: int = 42,
) -> AnnData:
    """
    Build spatial-temporal windows for local OT velocity computation.

    Instead of sorting all cells by pseudotime and creating global
    sequential windows (which fails when pseudotime does not correspond to
    spatial locality), this function:

    1. Clusters cells spatially in the ``basis`` embedding
    2. Within each cluster, sorts cells by pseudotime
    3. Creates overlapping temporal windows within each cluster
    4. Pairs consecutive windows for OT computation

    This ensures OT is only computed between spatially nearby cells, even
    when cells of very different pseudotime coexist in the same region of
    the embedding.

    Parameters
    ----------
    adata
        Must contain ``adata.obsm[basis]`` and ``adata.obs['pseudotime']``.
    basis
        Key in ``adata.obsm`` for the embedding to cluster in.
    n_clusters
        Number of spatial clusters. If None, ``max(5, min(30, n_cells // 100))``.
    window_size
        Number of cells per temporal window. If None, a quarter of the mean
        cluster size, but at least ``min_window_size``.
    overlap_fraction
        Fraction of overlap between consecutive temporal windows within
        each cluster. 0.5 means 50% overlap.
    min_window_size
        Minimum number of cells to form a valid window. Clusters smaller
        than ``2 * min_window_size`` are skipped.
    spatial_key
        Column in ``adata.obs`` with precomputed integer cluster labels.
        If provided, skips KMeans clustering. Labels do not need to be
        contiguous or start at zero.
    tail_handling
        What to do with the cells left after the last regular window:

        - ``"force"``: append one final window covering the last cells.
        - ``"drop"``: append that final window only when at least
          ``tail_threshold`` cells lie at or beyond the position where the
          next regular window would have started.
        - ``"split"``: append two full-sized windows that share the tail
          (falls back to ``"force"`` when only one cell is left).

        Any other value triggers a warning and is treated as ``"force"``.
    tail_threshold
        Minimum leftover cell count for ``tail_handling="drop"``.
    random_state
        Random seed for KMeans.

    Returns
    -------
    adata, modified in place with:
        - ``adata.obs['velot_spatial_cluster']`` : spatial cluster labels
        - ``adata.uns['velot_windows']`` : dict with window pair info
    """
    required_obs = ["pseudotime"]
    if spatial_key is not None:
        required_obs.append(spatial_key)
    _check_fields(adata, obsm_keys=[basis], obs_keys=required_obs)

    if tail_handling not in ("force", "drop", "split"):
        warnings.warn(
            f"Unknown tail_handling={tail_handling!r}; expected 'force', "
            f"'drop' or 'split'. Falling back to 'force'.",
            stacklevel=2,
        )
        tail_handling = "force"

    X = adata.obsm[basis]
    pseudotime = adata.obs["pseudotime"].values
    n_cells = adata.n_obs

    # ------------------------------------------------------------------
    # Step 1: Spatial clustering
    # ------------------------------------------------------------------
    if spatial_key is not None:
        labels = adata.obs[spatial_key].values.astype(int)
        # Iterate over the labels that actually occur: user-supplied labels
        # are not guaranteed to be 0..n-1.
        cluster_ids = np.unique(labels)
        n_clust = len(cluster_ids)
        print(f" Using precomputed spatial labels from '{spatial_key}': "
              f"{n_clust} clusters")
    else:
        if n_clusters is None:
            n_clusters = max(5, min(30, n_cells // 100))
        km = KMeans(
            n_clusters=n_clusters,
            random_state=random_state,
            n_init=10,
        ).fit(X)
        labels = km.labels_
        n_clust = n_clusters
        cluster_ids = np.arange(n_clust)
        print(f" Spatial clustering: {n_clust} clusters via KMeans "
              f"on {basis} ({X.shape[1]}D)")

    adata.obs["velot_spatial_cluster"] = labels

    # ------------------------------------------------------------------
    # Step 2: Within each cluster, create temporal windows
    # ------------------------------------------------------------------
    if window_size is None:
        # Aim for ~3-5 windows per cluster on average
        mean_cluster_size = n_cells / n_clust
        window_size = max(min_window_size, int(mean_cluster_size / 4))

    step_size = max(1, int(window_size * (1.0 - overlap_fraction)))

    window_pairs = []
    skipped_clusters = 0

    for c in cluster_ids:
        cluster_indices = np.flatnonzero(labels == c)
        if len(cluster_indices) < 2 * min_window_size:
            skipped_clusters += 1
            continue

        # Sort by pseudotime within this cluster
        order = np.argsort(pseudotime[cluster_indices])
        sorted_indices = cluster_indices[order]

        # Adapt window size for small clusters
        local_ws = min(window_size, len(sorted_indices) // 2)
        local_ws = max(local_ws, min_window_size)
        local_step = max(1, int(local_ws * (1.0 - overlap_fraction)))

        windows = _temporal_windows(
            sorted_indices, local_ws, local_step, tail_handling, tail_threshold
        )

        # Pair consecutive windows
        window_pairs.extend(zip(windows[:-1], windows[1:]))

    # ------------------------------------------------------------------
    # Store results
    # ------------------------------------------------------------------
    adata.uns["velot_windows"] = {
        "pairs": window_pairs,
        "n_pairs": len(window_pairs),
        "n_clusters": n_clust,
        "window_size": window_size,
        "overlap_fraction": overlap_fraction,
    }

    print(f" Built {len(window_pairs)} window pairs across "
          f"{n_clust - skipped_clusters} clusters "
          f"(skipped {skipped_clusters} small clusters)")
    print(f" Window size: {window_size}, step: {step_size}")

    return adata
# =====================================================================
# 2. OPTIMAL TRANSPORT VELOCITY
# =====================================================================


def _ot_velocity_pair(
    X: np.ndarray,
    idx_source: np.ndarray,
    idx_target: np.ndarray,
    pseudotime: np.ndarray,
    knn_adj,
    reg: float = 0.05,
    lambda_time: float = 1.0,
    lambda_knn: float = 1.0,
) -> np.ndarray:
    """
    Estimate velocity for the source cells of a single window pair.

    A transport plan between the source and target windows is computed on
    a max-normalized squared-Euclidean cost, with ``lambda_time`` added
    wherever the source pseudotime exceeds the target pseudotime and
    ``lambda_knn`` added wherever ``knn_adj`` has no edge. Each source
    cell's velocity is its plan-weighted mean target position minus its
    own position.

    Returns an array shaped like ``X[idx_source]``; all zeros when no
    usable plan could be computed.

    Adapted from the original VelOT compute_ot_velocity.
    """
    src_coords = X[idx_source]
    tgt_coords = X[idx_target]
    n_src, n_tgt = src_coords.shape[0], tgt_coords.shape[0]

    # Uniform marginals on both windows
    mass_src = np.ones(n_src, dtype=np.float64) / n_src
    mass_tgt = np.ones(n_tgt, dtype=np.float64) / n_tgt

    # Squared-Euclidean cost, rescaled to [0, 1] when non-degenerate
    cost = pot.dist(src_coords, tgt_coords, metric="sqeuclidean")
    largest_cost = cost.max()
    if largest_cost > 0:
        cost = cost / largest_cost

    # Discourage transport that goes backward in pseudotime
    goes_backward = (
        pseudotime[idx_source][:, None] - pseudotime[idx_target][None, :]
    ) > 0
    cost[goes_backward] += lambda_time

    # Discourage transport between cells that are not KNN neighbors
    use_knn = knn_adj is not None
    if use_knn:
        local_adj = knn_adj[idx_source][:, idx_target]
        if hasattr(local_adj, "toarray"):
            local_adj = local_adj.toarray()
        cost[local_adj == 0] += lambda_knn

    # Entropic OT first; exact OT if that raises; zeros if both raise
    try:
        plan = pot.sinkhorn(
            mass_src, mass_tgt, cost, reg=reg, numItermax=500, stopThr=1e-6
        )
    except Exception:
        try:
            plan = pot.emd(mass_src, mass_tgt, cost)
        except Exception:
            return np.zeros_like(src_coords)

    # Reject plans with NaN/inf entries or essentially no mass
    plan_usable = np.isfinite(plan).all() and plan.sum() >= 1e-10
    if not plan_usable:
        return np.zeros_like(src_coords)

    # Barycentric target position minus current position
    row_mass = plan.sum(axis=1, keepdims=True)
    row_mass[row_mass == 0] = 1.0
    velocity = (plan @ tgt_coords) / row_mass - src_coords

    # Source cells without any KNN neighbor in the target get no velocity
    if use_knn:
        isolated = ~(local_adj.sum(axis=1) > 0)
        velocity[isolated] = 0.0

    return velocity
def compute_ot_velocity(
    adata: AnnData,
    basis: str = "X_pca",
    reg: float = 0.05,
    lambda_time: float = 1.0,
    lambda_knn: float = 1.0,
) -> AnnData:
    """
    Compute raw OT velocity from spatial-temporal windows.

    For each window pair, optimal transport is used to match cells in the
    source window to cells in the target window. The velocity for each
    cell is the weighted displacement under the transport plan.

    Results are averaged across all window pairs where a cell appears as a
    source. A per-cell confidence score tracks how many windows
    contributed to each cell's velocity estimate, normalized by the
    largest such count.

    Parameters
    ----------
    adata
        Must contain ``adata.obsm[basis]``, ``adata.obs['pseudotime']`` and
        windows from ``velot.tl.build_windows()``. If
        ``adata.obsp['connectivities']`` exists it is used for the KNN
        locality penalty.
    basis
        Embedding key in ``adata.obsm``. Only ``"X_pca"`` and ``"X_umap"``
        are supported.
    reg
        Sinkhorn entropy regularization. Smaller = sharper plans.
    lambda_time
        Penalty added to cost for backward-in-time transport.
    lambda_knn
        Penalty added to cost for transport between non-neighbors.

    Returns
    -------
    adata, modified in place with:
        - ``adata.obsm['velot_velocity_raw_pca']`` or
          ``adata.obsm['velot_velocity_raw_umap']`` (depending on
          ``basis``) : raw OT velocity in the ``basis`` space
        - ``adata.obs['velot_confidence']`` : normalized contribution count

    Raises
    ------
    ValueError
        If a required field is missing from ``adata``.
    NotImplementedError
        If ``basis`` is neither ``"X_pca"`` nor ``"X_umap"``.
    """
    _check_fields(
        adata,
        obsm_keys=[basis],
        obs_keys=["pseudotime"],
        uns_keys=["velot_windows"],
    )

    # Resolve the output key before doing any work so that an unsupported
    # basis fails immediately instead of after every OT problem was solved.
    output_keys = {
        "X_pca": "velot_velocity_raw_pca",
        "X_umap": "velot_velocity_raw_umap",
    }
    if basis not in output_keys:
        raise NotImplementedError(
            f"compute_ot_velocity supports basis 'X_pca' or 'X_umap', "
            f"got '{basis}'."
        )

    X = adata.obsm[basis]
    pseudotime = adata.obs["pseudotime"].values
    n_cells = adata.n_obs
    dim = X.shape[1]

    window_pairs = adata.uns["velot_windows"]["pairs"]

    # Get KNN adjacency for locality penalty
    knn_adj = None
    if "connectivities" in adata.obsp:
        knn_adj = adata.obsp["connectivities"]

    # Accumulate velocity across all window pairs
    V = np.zeros((n_cells, dim), dtype=np.float64)
    counts = np.zeros(n_cells, dtype=np.float64)

    for idx_src, idx_tgt in window_pairs:
        idx_src = np.asarray(idx_src)
        idx_tgt = np.asarray(idx_tgt)
        v_local = _ot_velocity_pair(
            X, idx_src, idx_tgt, pseudotime, knn_adj,
            reg=reg, lambda_time=lambda_time, lambda_knn=lambda_knn,
        )
        # Unbuffered scatter-add: correct even if an index were repeated.
        np.add.at(V, idx_src, v_local)
        np.add.at(counts, idx_src, 1.0)

    # Average velocity across contributing windows
    mask = counts > 0
    V[mask] /= counts[mask, None]

    # Confidence: normalized contribution count
    max_count = counts.max() if counts.max() > 0 else 1.0
    confidence = counts / max_count

    adata.obsm[output_keys[basis]] = V
    adata.obs["velot_confidence"] = confidence

    n_with_velocity = mask.sum()
    n_zero = (~mask).sum()
    print(f" OT velocity computed: {n_with_velocity} cells with velocity, "
          f"{n_zero} cells without (will be filled by smoothing)")

    return adata
# =====================================================================
# 3. NEURAL VELOCITY FIELD SMOOTHING
# =====================================================================


def _build_knn_index(X: np.ndarray, k: int = 15) -> np.ndarray:
    """
    Build a fixed-k nearest neighbor index using cKDTree.

    Each cell's own index is removed from its neighbor list by value, not
    by position, so cells with identical coordinates are handled
    correctly. ``k`` is capped at ``n_cells - 1`` (with a warning), since
    cKDTree otherwise pads the result with the out-of-range index
    ``n_cells``.

    Returns
    -------
    knn_indices : ndarray of shape (n_cells, min(k, n_cells - 1))
        For each cell, the indices of its nearest neighbors, closest first.
    """
    X = np.asarray(X)
    n_cells = X.shape[0]

    k_eff = min(int(k), n_cells - 1)
    if k_eff < k:
        warnings.warn(
            f"Requested k={k} neighbors but only {n_cells} cells are "
            f"available; using k={max(k_eff, 0)}.",
            stacklevel=2,
        )
    if k_eff <= 0:
        return np.empty((n_cells, 0), dtype=np.intp)

    tree = cKDTree(X)
    _, hits = tree.query(X, k=k_eff + 1)  # +1 because self is included
    hits = np.asarray(hits).reshape(n_cells, k_eff + 1)

    # Mark the self entry of every row. With exact duplicates the query
    # point is not guaranteed to come back in column 0.
    is_self = hits == np.arange(n_cells)[:, None]
    # If self was pushed out entirely by duplicates, drop the farthest hit
    # so every row still loses exactly one entry.
    self_missing = ~is_self.any(axis=1)
    is_self[self_missing, -1] = True

    return hits[~is_self].reshape(n_cells, k_eff)


# torch is an optional dependency. Fall back to a plain base class so that
# importing this module does not fail with a NameError when torch (and
# therefore ``nn``) is unavailable; the network itself is only usable with
# torch installed.
try:
    _ModuleBase = nn.Module
except NameError:
    _ModuleBase = object


class _VelocityNet(_ModuleBase):
    """Small MLP that predicts velocity from embedding coordinates.

    The network has two hidden layers of width ``hidden`` with SiLU
    activations. When ``use_pseudotime`` is True the input layer expects
    one extra feature (pseudotime) appended to the coordinates.
    """

    def __init__(self, dim: int, hidden: int = 128, use_pseudotime: bool = True):
        super().__init__()
        self.use_pseudotime = use_pseudotime
        in_dim = dim + (1 if use_pseudotime else 0)
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(
        self, x: torch.Tensor, pt: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Predict velocity at ``x``; ``pt`` is appended as the last feature.

        ``pt`` is only used when the network was built with
        ``use_pseudotime=True``; in that case callers must supply it, as
        the input layer is sized for the extra feature.
        """
        if self.use_pseudotime and pt is not None:
            x = torch.cat([x, pt], dim=-1)
        return self.net(x)

    def forward_with_jacobian(
        self, x: torch.Tensor, pt: Optional[torch.Tensor] = None
    ):
        """
        Forward pass on a detached, gradient-enabled copy of ``x``.

        Returns ``(v, x_in)`` so the caller can differentiate ``v`` with
        respect to ``x_in`` (e.g. for curl and divergence regularization).
        """
        x_in = x.detach().clone().requires_grad_(True)
        if self.use_pseudotime and pt is not None:
            inp = torch.cat([x_in, pt], dim=-1)
        else:
            inp = x_in
        v = self.net(inp)
        return v, x_in
def smooth_velocity(
    adata: AnnData,
    basis: str = "pca",
    velocity_key: str = "velot_velocity_raw",
    n_epochs: int = 200,
    hidden_dim: int = 128,
    lr: float = 1e-3,
    batch_size: int = 256,
    lambda_smooth: float = 0.5,
    lambda_curl: float = 0.5,
    lambda_divergence: float = 0.0,
    k_smooth: int = 15,
    use_pseudotime: bool = True,
    random_state: int = 42,
    verbose: bool = True,
) -> AnnData:
    """
    Smooth the raw OT velocity field using a neural network.

    Three types of regularization are available:

    - **Smoothness** (``lambda_smooth``): neighboring cells should have
      similar velocities. Reduces noise and zig-zagging.
    - **Curl penalty** (``lambda_curl``): penalizes rotational components
      of the velocity field. Encourages irrotational flow where
      streamlines do not form loops.
    - **Divergence penalty** (``lambda_divergence``): penalizes the
      divergence of the velocity field. Encourages incompressible-like
      flow where cells neither accumulate nor deplete locally.

    Curl and divergence are evaluated on the first three embedding
    dimensions only. Each "epoch" is one optimizer step on a random batch.
    For the first 20 epochs the losses are combined with the raw lambda
    weights; afterwards each loss is divided by its value at epoch 20 so
    that the lambdas have comparable effect.

    Parameters
    ----------
    adata
        Must contain ``adata.obsm['X_<basis>']``, the raw velocity under
        ``velocity_key`` and ``adata.obs['velot_confidence']`` (plus
        ``adata.obs['pseudotime']`` when ``use_pseudotime`` is True).
    basis
        Embedding name without the ``X_`` prefix (e.g. ``"pca"``).
    velocity_key
        Key in ``adata.obsm`` holding the raw velocity. If the key is
        absent but ``f"{velocity_key}_{basis}"`` exists (the naming used
        by ``compute_ot_velocity``), that key is used instead.
    n_epochs
        Number of training steps.
    hidden_dim
        Network hidden layer width.
    lr
        Learning rate.
    batch_size
        Cells per batch.
    lambda_smooth
        Weight of KNN smoothness loss.
    lambda_curl
        Weight of curl penalty. Recommended range: 0.0 (off) to 0.5.
    lambda_divergence
        Weight of divergence penalty. Recommended range: 0.0 (off) to 0.5.
    k_smooth
        Number of neighbors for smoothness.
    use_pseudotime
        Condition network on pseudotime.
    random_state
        Random seed.
    verbose
        Print progress.

    Returns
    -------
    adata, modified in place with:
        - ``adata.obsm['velot_velocity_<basis>']`` : smoothed velocity
        - ``adata.uns['velot_smoothing']`` : loss histories, settings and
          the trained network

    Raises
    ------
    ImportError
        If PyTorch is not installed.
    ValueError
        If a required field is missing from ``adata``.
    """
    if not _HAS_TORCH:
        raise ImportError(
            "PyTorch is required for velocity smoothing. "
            "Install with: pip3 install torch"
        )
    if verbose:
        print(f"Found torch compatible version. Running on {DEVICE} device")

    basis_key = f"X_{basis}"

    # compute_ot_velocity stores its result with a basis suffix
    # ("velot_velocity_raw_pca"); accept the un-suffixed default as well.
    suffixed_key = f"{velocity_key}_{basis}"
    if velocity_key not in adata.obsm and suffixed_key in adata.obsm:
        velocity_key = suffixed_key

    required_obs = ["velot_confidence"]
    if use_pseudotime:
        required_obs.append("pseudotime")
    _check_fields(
        adata,
        obsm_keys=[basis_key, velocity_key],
        obs_keys=required_obs,
    )

    torch.manual_seed(random_state)
    np.random.seed(random_state)

    X_np = adata.obsm[basis_key].astype(np.float32)
    V_np = adata.obsm[velocity_key].astype(np.float32)
    conf_np = adata.obs["velot_confidence"].values.astype(np.float32)
    n_cells, dim = X_np.shape

    knn_indices = _build_knn_index(X_np, k=k_smooth)
    k = knn_indices.shape[1]

    X_t = torch.tensor(X_np, device=DEVICE)
    V_t = torch.tensor(V_np, device=DEVICE)
    conf_t = torch.tensor(conf_np, device=DEVICE)
    knn_t = torch.tensor(knn_indices, dtype=torch.long, device=DEVICE)
    pt_t = None
    if use_pseudotime:
        pt_np = adata.obs["pseudotime"].values.astype(np.float32)
        pt_t = torch.tensor(pt_np, device=DEVICE).unsqueeze(-1)

    net = _VelocityNet(
        dim=dim, hidden=hidden_dim, use_pseudotime=use_pseudotime,
    ).to(DEVICE)

    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

    use_jacobian = (lambda_curl > 0) or (lambda_divergence > 0)

    net.train()
    losses_reg = []
    losses_smooth = []
    losses_curl = []
    losses_div = []
    losses_total = []

    # Loss magnitudes recorded right after warmup, used for normalization
    init_scales = {}
    warmup_epochs = 20  # use raw weights during warmup

    for epoch in range(n_epochs):
        idx = torch.randint(0, n_cells, (min(batch_size, n_cells),), device=DEVICE)
        B = idx.shape[0]

        # Evaluate the batch cells and all their neighbors in one pass
        nbr_idx = knn_t[idx]
        all_idx = torch.cat([idx, nbr_idx.reshape(-1)])

        X_all = X_t[all_idx]
        pt_all = pt_t[all_idx] if use_pseudotime else None
        V_all = net(X_all, pt_all)

        V_batch = V_all[:B]
        V_nbr = V_all[B:].reshape(B, k, dim)

        # Confidence-weighted regression onto the raw OT velocity
        V_target = V_t[idx]
        conf_batch = conf_t[idx]
        sq_err = ((V_batch - V_target) ** 2).sum(dim=-1)
        loss_reg = (conf_batch * sq_err).sum() / (conf_batch.sum() + 1e-8)

        # Neighboring cells should have similar predicted velocity
        diff = V_batch.unsqueeze(1) - V_nbr
        loss_smooth = (diff ** 2).mean()

        # Curl and divergence
        loss_curl_val = torch.tensor(0.0, device=DEVICE)
        loss_div_val = torch.tensor(0.0, device=DEVICE)

        if use_jacobian:
            x_jac = X_t[idx].detach().clone().requires_grad_(True)
            pt_jac = pt_t[idx] if use_pseudotime else None

            if use_pseudotime and pt_jac is not None:
                inp_jac = torch.cat([x_jac, pt_jac], dim=-1)
            else:
                inp_jac = x_jac

            v_jac = net.net(inp_jac)

            # Row d holds d v_d / d x for the first (up to) three outputs
            jac_rows = []
            for d in range(min(dim, 3)):
                grad_d = torch.autograd.grad(
                    v_jac[:, d].sum(), x_jac,
                    create_graph=True, retain_graph=True,
                )[0]
                jac_rows.append(grad_d)

            if len(jac_rows) >= 2:
                n_jac_dims = min(dim, len(jac_rows))

                if lambda_divergence > 0:
                    divergence = sum(
                        jac_rows[d][:, d] for d in range(n_jac_dims)
                    )
                    loss_div_val = (divergence ** 2).mean()

                if lambda_curl > 0:
                    curl_components = []
                    for d1 in range(n_jac_dims):
                        for d2 in range(d1 + 1, n_jac_dims):
                            curl_d1d2 = jac_rows[d2][:, d1] - jac_rows[d1][:, d2]
                            curl_components.append(curl_d1d2)

                    if curl_components:
                        curl_stack = torch.stack(curl_components, dim=-1)
                        loss_curl_val = (curl_stack ** 2).mean()

        # Record initial scales after warmup
        if epoch == warmup_epochs:
            init_scales["reg"] = max(loss_reg.item(), 1e-6)
            init_scales["smooth"] = max(loss_smooth.item(), 1e-6)
            if lambda_curl > 0:
                init_scales["curl"] = max(loss_curl_val.item(), 1e-6)
            if lambda_divergence > 0:
                init_scales["div"] = max(loss_div_val.item(), 1e-6)

        # Compute total loss with scale normalization
        if epoch < warmup_epochs or not init_scales:
            # During warmup: all enabled terms, combined with raw weights
            loss = loss_reg + lambda_smooth * loss_smooth
            if lambda_curl > 0:
                loss = loss + lambda_curl * loss_curl_val
            if lambda_divergence > 0:
                loss = loss + lambda_divergence * loss_div_val
        else:
            # After warmup: normalize each loss by its initial scale
            # so that lambda values have comparable effect
            loss = loss_reg / init_scales["reg"]
            loss = loss + lambda_smooth * loss_smooth / init_scales["smooth"]
            if lambda_curl > 0:
                loss = loss + lambda_curl * loss_curl_val / init_scales["curl"]
            if lambda_divergence > 0:
                loss = loss + lambda_divergence * loss_div_val / init_scales["div"]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        losses_reg.append(loss_reg.item())
        losses_smooth.append(loss_smooth.item())
        losses_curl.append(loss_curl_val.item())
        losses_div.append(loss_div_val.item())
        losses_total.append(loss.item())

        if verbose and (epoch + 1) % 50 == 0:
            msg = (f" epoch {epoch+1}/{n_epochs}: "
                   f"total={loss.item():.4f} "
                   f"reg={loss_reg.item():.4f} "
                   f"smooth={loss_smooth.item():.4f}")
            if lambda_curl > 0:
                msg += f" curl={loss_curl_val.item():.4f}"
            if lambda_divergence > 0:
                msg += f" div={loss_div_val.item():.4f}"
            print(msg)

    # Extract smoothed velocity for every cell, in chunks to bound memory
    net.eval()
    with torch.no_grad():
        chunk_size = 2048
        V_smooth = []
        for start in range(0, n_cells, chunk_size):
            end = min(start + chunk_size, n_cells)
            X_chunk = X_t[start:end]
            pt_chunk = pt_t[start:end] if use_pseudotime else None
            V_smooth.append(net(X_chunk, pt_chunk).cpu().numpy())
        V_smooth = np.concatenate(V_smooth, axis=0)

    adata.obsm[f"velot_velocity_{basis}"] = V_smooth

    # NOTE(review): "network" is a live torch module kept for
    # query_velocity(); it presumably cannot be serialized when adata is
    # written to disk -- confirm and drop it before saving if so.
    adata.uns["velot_smoothing"] = {
        "n_epochs": n_epochs,
        "losses_regression": losses_reg,
        "losses_smoothness": losses_smooth,
        "losses_curl": losses_curl,
        "losses_divergence": losses_div,
        "losses_total": losses_total,
        "lambda_smooth": lambda_smooth,
        "lambda_curl": lambda_curl,
        "lambda_divergence": lambda_divergence,
        "network": net,
        "device": str(DEVICE),
        "dim": dim,
        "use_pseudotime": use_pseudotime,
    }

    # losses_total is empty when n_epochs == 0
    if verbose and losses_total:
        print(f" Smoothing complete: final total={losses_total[-1]:.4f}")

    return adata
def query_velocity(
    adata: AnnData,
    positions: np.ndarray,
    pseudotime_values: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Query the learned continuous velocity field at arbitrary positions.

    Unlike the velocity vectors stored in ``adata.obsm``, which are only
    defined at cell locations, this function evaluates the trained neural
    velocity field (``adata.uns['velot_smoothing']['network']``) at any
    point of the embedding it was trained on.

    Parameters
    ----------
    adata
        Must contain ``adata.uns['velot_smoothing']`` as written by
        ``velot.tl.smooth_velocity()``.
    positions
        Array of shape (n_points, D) with coordinates to query. A single
        point of shape (D,) is accepted and treated as (1, D).
    pseudotime_values
        Array of shape (n_points,) with pseudotime values for each query
        point; a single value is broadcast to all points. If None, uses
        the median pseudotime from the dataset. Ignored when the network
        was trained without pseudotime.

    Returns
    -------
    Array of shape (n_points, D) with velocity vectors at each query
    position.

    Raises
    ------
    ValueError
        If no smoothing network is stored, if the number of columns of
        ``positions`` differs from the training dimension, or if
        ``pseudotime_values`` does not match the number of points.

    Example
    -------
    ::

        # Query at cell positions (same as stored vectors)
        V = velot.tl.query_velocity(adata, adata.obsm["X_pca"])

        # Query at arbitrary positions
        import numpy as np
        grid_points = np.random.randn(100, 30) * 0.1
        V_grid = velot.tl.query_velocity(adata, grid_points)
    """
    if "velot_smoothing" not in adata.uns:
        raise ValueError(
            "No smoothing network found. Run velot.tl.smooth_velocity() first."
        )

    info = adata.uns["velot_smoothing"]
    net = info["network"]
    use_pt = info["use_pseudotime"]

    positions = np.atleast_2d(np.asarray(positions, dtype=np.float32))
    n_points, n_dims = positions.shape

    expected_dim = info.get("dim")
    if expected_dim is not None and n_dims != expected_dim:
        raise ValueError(
            f"positions has {n_dims} columns but the velocity network was "
            f"trained on {expected_dim} dimensions."
        )

    # Nothing to evaluate; np.concatenate would fail on an empty list.
    if n_points == 0:
        return np.zeros((0, n_dims), dtype=np.float32)

    # Put the inputs wherever the network's weights live, which may differ
    # from the module-level default device.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    pt_t = None
    if use_pt:
        if pseudotime_values is None:
            median_pt = float(np.median(adata.obs["pseudotime"].values))
            pt_np = np.full(n_points, median_pt, dtype=np.float32)
        else:
            pt_np = np.asarray(pseudotime_values, dtype=np.float32).reshape(-1)
            if pt_np.size == 1:
                pt_np = np.full(n_points, pt_np[0], dtype=np.float32)
            elif pt_np.size != n_points:
                raise ValueError(
                    f"pseudotime_values has {pt_np.size} entries but "
                    f"positions has {n_points} points."
                )
        pt_t = torch.tensor(pt_np, device=device).unsqueeze(-1)

    X_t = torch.tensor(positions, device=device)

    net.eval()
    with torch.no_grad():
        chunk_size = 2048
        V_parts = []
        for start in range(0, n_points, chunk_size):
            end = min(start + chunk_size, n_points)
            X_chunk = X_t[start:end]
            pt_chunk = pt_t[start:end] if use_pt else None
            V_parts.append(net(X_chunk, pt_chunk).cpu().numpy())

    return np.concatenate(V_parts, axis=0)
# ===================================================================== # 4. PROJECTION TO UMAP # =====================================================================
def project_to_embedding(
    adata: AnnData,
    velocity_key: str = "velot_velocity_pca",
    velocity_key_umap: str = "velot_velocity_umap",
    basis_pca: str = "X_pca",
    basis_embedding: str = "X_umap",
    n_neighbors: int = 30,
) -> AnnData:
    """
    Project velocity from PCA space to any low-dimensional embedding for
    visualization.

    Uses a local linear approximation: for each cell, the Jacobian of the
    PCA→embedding (UMAP, t-SNE, ...) mapping is estimated from its KNN
    neighborhood via least-squares, and the PCA velocity is transformed
    accordingly.

    Note: the embedding is for visualization only. The velocity in the
    source space (``adata.obsm[velocity_key]``) is the primary output.

    Parameters
    ----------
    adata
        Must contain ``adata.obsm[velocity_key]``, ``adata.obsm[basis_pca]``
        and ``adata.obsm[basis_embedding]``.
    velocity_key
        Key in ``adata.obsm`` with the velocity to project.
    velocity_key_umap
        Key in ``adata.obsm`` under which the projected velocity is stored.
    basis_pca
        Key in ``adata.obsm`` of the space the velocity lives in.
    basis_embedding
        Key in ``adata.obsm`` of the target embedding. ``"X_tsne"``
        triggers a warning about distorted global distances.
    n_neighbors
        Number of nearest neighbors (in ``basis_pca``) used for the local
        least-squares fit.

    Returns
    -------
    adata with ``adata.obsm[velocity_key_umap]`` populated.
    """
    _check_fields(
        adata,
        obsm_keys=[velocity_key, basis_pca, basis_embedding],
    )

    if basis_embedding == "X_tsne":
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn(
            "t-SNE distorts global distances. Velocity arrows show "
            "local directions correctly but arrow lengths and "
            "cross-cluster directions may be misleading. "
            "Consider using UMAP for velocity visualization.",
            UserWarning,
            stacklevel=2,
        )

    X_pca = adata.obsm[basis_pca]
    X_umap = adata.obsm[basis_embedding]
    V_pca = adata.obsm[velocity_key]
    n_cells = X_pca.shape[0]

    # Build KNN in PCA space for the local linear approximation
    knn_indices = _build_knn_index(X_pca, k=n_neighbors)

    V_umap = np.zeros_like(X_umap)

    for i in range(n_cells):
        nbrs = knn_indices[i]

        # Local displacement in PCA and embedding space
        dX = X_pca[nbrs] - X_pca[i]    # (k, d_pca)
        dU = X_umap[nbrs] - X_umap[i]  # (k, d_embedding)

        # Least-squares solution A of dU ≈ dX @ A (minimum-norm when the
        # system is under-determined)
        A, _, _, _ = np.linalg.lstsq(dX, dU, rcond=None)

        # Project PCA velocity through the local Jacobian
        V_umap[i] = V_pca[i] @ A

    adata.obsm[velocity_key_umap] = V_umap
    print(f" Velocity projected to 2D embedding {basis_embedding} ({n_cells} cells)")

    return adata
def project_to_umap(
    adata: AnnData,
    velocity_key: str = "velot_velocity_pca",
    velocity_key_umap: str = "velot_velocity_umap",
    basis_pca: str = "X_pca",
    basis_umap: str = "X_umap",
    n_neighbors: int = 30,
) -> AnnData:
    """
    Project velocity from PCA space to UMAP space for visualization.

    Thin wrapper around ``project_to_embedding()`` with
    ``basis_embedding=basis_umap``, kept so the two projections cannot
    drift apart. The progress line printed is therefore the one emitted
    by ``project_to_embedding()``.

    Note: UMAP is for visualization only. The velocity in the source
    space (``adata.obsm[velocity_key]``) is the primary output.

    Parameters
    ----------
    adata
        Must contain ``adata.obsm[velocity_key]``, ``adata.obsm[basis_pca]``
        and ``adata.obsm[basis_umap]``.
    velocity_key
        Key in ``adata.obsm`` with the velocity to project.
    velocity_key_umap
        Key in ``adata.obsm`` under which the projected velocity is stored.
    basis_pca
        Key in ``adata.obsm`` of the space the velocity lives in.
    basis_umap
        Key in ``adata.obsm`` of the target embedding.
    n_neighbors
        Number of nearest neighbors used for the local least-squares fit.

    Returns
    -------
    adata with ``adata.obsm[velocity_key_umap]`` populated.
    """
    return project_to_embedding(
        adata,
        velocity_key=velocity_key,
        velocity_key_umap=velocity_key_umap,
        basis_pca=basis_pca,
        basis_embedding=basis_umap,
        n_neighbors=n_neighbors,
    )
# ===================================================================== # 5. ORCHESTRATOR # =====================================================================
def velocity(
    adata: AnnData,
    basis: str = "X_pca",
    smooth: bool = True,
    # Windowing params
    n_clusters: Optional[int] = None,
    window_size: Optional[int] = None,
    overlap_fraction: float = 0.5,
    min_window_size: int = 20,
    spatial_key: Optional[str] = None,
    tail_handling: str = "force",
    tail_threshold: int = 10,
    # OT params
    reg: float = 0.05,
    lambda_time: float = 1.0,
    lambda_knn: float = 1.0,
    # Smoothing params
    n_epochs: int = 200,
    hidden_dim: int = 128,
    lambda_smooth: float = 0.5,
    lambda_curl: float = 0.0,
    lambda_divergence: float = 0.0,
    k_smooth: int = 15,
    use_pseudotime: bool = True,
    # Output
    project_umap: bool = True,
    project_basis: str = "X_umap",
    random_state: int = 42,
    verbose: bool = True,
) -> AnnData:
    """
    Run the full VelOT velocity pipeline.

    This is a convenience function that calls, in order:

    1. ``build_windows()`` — spatial-temporal windowing
    2. ``compute_ot_velocity()`` — local OT velocity
    3. ``smooth_velocity()`` — neural smoothing (optional)
    4. ``project_to_umap()`` — projection of the raw and, if available,
       smoothed velocity onto ``project_basis`` (optional)

    Parameters
    ----------
    adata
        Preprocessed AnnData (run ``velot.pp.prepare()`` first). Must
        contain ``adata.obsm[basis]`` and ``adata.obs['pseudotime']``.
    basis
        Embedding key for velocity computation; ``"X_pca"`` or ``"X_umap"``.
    smooth
        Whether to apply neural smoothing.
    n_clusters, window_size, overlap_fraction, min_window_size, \
spatial_key, tail_handling, tail_threshold
        Forwarded to ``build_windows()``.
    reg, lambda_time, lambda_knn
        Forwarded to ``compute_ot_velocity()``.
    n_epochs, hidden_dim, lambda_smooth, lambda_curl, lambda_divergence, \
k_smooth, use_pseudotime
        Forwarded to ``smooth_velocity()``.
    project_umap
        Whether to project velocity onto ``project_basis`` for
        visualization. Skipped when ``project_basis`` is missing from
        ``adata.obsm`` or equals ``basis``.
    project_basis
        Key in ``adata.obsm`` of the visualization embedding.
    random_state
        Seed forwarded to windowing and smoothing.
    verbose
        Whether to print progress.

    Returns
    -------
    adata with velocity fields computed. For ``basis="X_pca"`` and
    ``project_basis="X_umap"`` these are ``velot_velocity_raw_pca``,
    ``velot_velocity_pca`` (if smoothed), ``velot_velocity_raw_umap`` and
    ``velot_velocity_umap`` (if smoothed) in ``adata.obsm``.

    Raises
    ------
    ValueError
        If a required field is missing from ``adata``.
    NotImplementedError
        If ``basis`` is neither ``"X_pca"`` nor ``"X_umap"``.

    Example
    -------
    ::

        import velot
        import scvelo as scv

        adata = scv.datasets.pancreas()
        velot.pp.prepare(adata, root_cluster="Ductal")
        velot.tl.velocity(adata)
        velot.pl.velocity_stream(adata)
    """
    _check_fields(adata, obsm_keys=[basis], obs_keys=["pseudotime"])

    # Keys written by compute_ot_velocity(); fail before any heavy work
    # when the basis is one it cannot handle.
    raw_keys = {
        "X_pca": "velot_velocity_raw_pca",
        "X_umap": "velot_velocity_raw_umap",
    }
    if basis not in raw_keys:
        raise NotImplementedError(
            f"velocity supports basis 'X_pca' or 'X_umap', got '{basis}'."
        )
    raw_key = raw_keys[basis]
    basis_name = basis[len("X_"):]
    smooth_key = f"velot_velocity_{basis_name}"

    if verbose:
        print("VelOT: Computing velocity field")
        print(f" Basis: {basis} ({adata.obsm[basis].shape[1]}D)")
        print(f" Cells: {adata.n_obs}")
        print(f" Smoothing: {'ON' if smooth else 'OFF'}")
        print()

    # Step 1: Windowing
    if verbose:
        print("[1/4] Building spatial-temporal windows...")
    build_windows(
        adata,
        basis=basis,
        n_clusters=n_clusters,
        window_size=window_size,
        overlap_fraction=overlap_fraction,
        min_window_size=min_window_size,
        spatial_key=spatial_key,
        tail_handling=tail_handling,
        tail_threshold=tail_threshold,
        random_state=random_state,
    )

    # Step 2: OT velocity
    if verbose:
        print("\n[2/4] Computing OT velocity...")
    compute_ot_velocity(
        adata,
        basis=basis,
        reg=reg,
        lambda_time=lambda_time,
        lambda_knn=lambda_knn,
    )

    # Step 3: Smoothing
    if smooth:
        if verbose:
            print("\n[3/4] Smoothing velocity field...")
        smooth_velocity(
            adata,
            basis=basis_name,
            velocity_key=raw_key,
            n_epochs=n_epochs,
            hidden_dim=hidden_dim,
            lambda_smooth=lambda_smooth,
            lambda_curl=lambda_curl,
            lambda_divergence=lambda_divergence,
            k_smooth=k_smooth,
            use_pseudotime=use_pseudotime,
            random_state=random_state,
            verbose=verbose,
        )
    else:
        if verbose:
            print("\n[3/4] Smoothing: SKIPPED")

    # Step 4: Project to the visualization embedding. Projecting a space
    # onto itself is pointless and would overwrite the raw velocity.
    can_project = (
        project_umap
        and project_basis in adata.obsm
        and project_basis != basis
    )
    if can_project:
        if project_basis.startswith("X_"):
            target_name = project_basis[len("X_"):]
        else:
            target_name = project_basis
        if verbose:
            print("\n[4/4] Projecting to UMAP...")
        # Project from the space the velocity was actually computed in.
        project_to_umap(
            adata,
            raw_key,
            f"velot_velocity_raw_{target_name}",
            basis_pca=basis,
            basis_umap=project_basis,
        )
        if smooth:
            project_to_umap(
                adata,
                smooth_key,
                f"velot_velocity_{target_name}",
                basis_pca=basis,
                basis_umap=project_basis,
            )
    else:
        if verbose:
            print("\n[4/4] UMAP projection: SKIPPED")

    if verbose:
        print("\nVelOT: Done.")

    return adata
# =====================================================================
# INTERNAL HELPERS
# =====================================================================


def _check_fields(
    adata: AnnData,
    obsm_keys: list = None,
    obs_keys: list = None,
    uns_keys: list = None,
):
    """Raise ``ValueError`` for the first required field missing in adata.

    Keys are checked in order: ``adata.obsm``, then ``adata.obs``, then
    ``adata.uns``. Passing ``None`` (or an empty list) for a group skips
    it. Returns ``None`` when everything is present.
    """
    for wanted in obsm_keys or []:
        if wanted in adata.obsm:
            continue
        raise ValueError(
            f"'{wanted}' not found in adata.obsm. "
            f"Run velot.pp.prepare() first. "
            f"Available keys: {list(adata.obsm.keys())}"
        )

    for wanted in obs_keys or []:
        if wanted in adata.obs:
            continue
        raise ValueError(
            f"'{wanted}' not found in adata.obs. "
            f"Run velot.pp.prepare() first. "
            f"Available columns: {list(adata.obs.columns)}"
        )

    for wanted in uns_keys or []:
        if wanted in adata.uns:
            continue
        raise ValueError(
            f"'{wanted}' not found in adata.uns. "
            f"Run the required upstream step first."
        )


# =====================================================================
# 6. GRID SEARCH
# =====================================================================
def gridsearch(
    adata: AnnData,
    param_grid: dict,
    cluster_edges: Optional[list] = None,
    cluster_key: str = "clusters",
    velocity_metric_key: str = "velocity_umap",
    fixed_params: Optional[dict] = None,
    pseudotime_key: str = "pseudotime",
    verbose: bool = True,
):
    """
    Run the VelOT pipeline across multiple parameter combinations and
    collect evaluation metrics for each.

    Parameters
    ----------
    adata
        Preprocessed AnnData. Must already have PCA, neighbors, UMAP, and
        pseudotime computed (everything from ``velot.pp``). A fresh copy
        is used for each parameter combination.
    param_grid
        Dictionary mapping parameter names to lists of values.
        Parameter names must match arguments of ``velot.tl.velocity()``.
        Example::

            param_grid = {
                "basis": ["X_pca"],
                "reg": [0.01, 0.05, 0.1],
                "lambda_smooth": [0.1, 0.5, 1.0],
                "n_clusters": [10, 20],
            }

    cluster_edges
        Transition edges for CBDir metric. If None, only ICCoh is computed.
    cluster_key
        Column in adata.obs with cluster labels.
    velocity_metric_key
        Key in adata.obsm to evaluate metrics on. If the key is absent
        after a run but ``"velot_" + velocity_metric_key`` exists (the
        naming used by ``velocity()``), that key is evaluated instead.
    fixed_params
        Parameters passed to ``velot.tl.velocity()`` for every run that
        are NOT part of the grid. Grid values win on a name clash.
        Example::

            fixed_params = {"smooth": True, "n_epochs": 200}

    pseudotime_key
        Column in adata.obs with pseudotime. On each working copy it is
        exposed as ``obs['pseudotime']``, the column the pipeline reads.
    verbose
        Whether to print progress.

    Returns
    -------
    pandas DataFrame with one row per parameter combination and columns
    for each parameter, ICCoh mean, CBDir mean, elapsed seconds, a status
    string (``"ok"`` or ``"error: ..."``) and per-cluster/per-edge scores.

    Raises
    ------
    ValueError
        If ``pseudotime_key`` is not a column of ``adata.obs``.

    Example
    -------
    ::

        import velot

        adata = velot.datasets.dentategyrus()
        # ... preprocessing ...

        param_grid = {
            "reg": [0.01, 0.05, 0.1, 0.5],
            "lambda_smooth": [0.1, 0.5, 1.0],
            "n_clusters": [10, 20, 30],
        }

        edges = [
            ("OPC", "OL"),
            ("Neuroblast", "Granule immature"),
        ]

        results = velot.tl.gridsearch(
            adata,
            param_grid,
            cluster_edges=edges,
            cluster_key="clusters",
            fixed_params={"basis": "X_pca", "n_epochs": 200},
        )

        # Best by ICCoh
        print(results.sort_values("iccoh_mean", ascending=False).head())

        # Best by CBDir
        print(results.sort_values("cbdir_mean", ascending=False).head())
    """
    import time as _time
    from itertools import product

    import pandas as pd

    # Imported once, outside the per-run try block, so a broken metrics
    # module fails loudly instead of marking every run as an error.
    from . import metrics as _metrics

    # Validate
    if pseudotime_key not in adata.obs:
        raise ValueError(
            f"'{pseudotime_key}' not found in adata.obs. "
            f"Run velot.pp.pseudotime() first."
        )

    fixed_params = fixed_params or {}

    # Build all combinations
    param_names = sorted(param_grid.keys())
    param_values = [param_grid[k] for k in param_names]
    combinations = list(product(*param_values))
    n_combos = len(combinations)

    if verbose:
        print(f"VelOT Grid Search: {n_combos} combinations")
        print(f" Parameters: {param_names}")
        print(f" Fixed: {fixed_params}")
        print()

    try:
        from tqdm.auto import tqdm
        iterator = tqdm(
            combinations,
            total=n_combos,
            desc="VelOT Grid Search",
            disable=not verbose,
        )
    except ImportError:
        iterator = combinations
        if verbose:
            warnings.warn(
                "Install tqdm for progress bars: pip install tqdm",
                stacklevel=2,
            )
    has_bar = hasattr(iterator, "set_postfix")

    # Run each combination
    rows = []
    for combo in iterator:
        params = dict(zip(param_names, combo))
        run_params = {**fixed_params, **params}

        # Update progress bar description
        if has_bar:
            iterator.set_postfix(dict(params), refresh=True)

        # Work on a fresh copy each time
        ad = adata.copy()
        if pseudotime_key != "pseudotime":
            # The pipeline always reads obs['pseudotime'].
            ad.obs["pseudotime"] = ad.obs[pseudotime_key]

        t0 = _time.time()
        try:
            velocity(ad, verbose=False, **run_params)
            elapsed = _time.time() - t0

            # velocity() writes "velot_"-prefixed keys; fall back to that
            # name when the requested key was not produced.
            metric_key = velocity_metric_key
            prefixed_key = f"velot_{velocity_metric_key}"
            if metric_key not in ad.obsm and prefixed_key in ad.obsm:
                metric_key = prefixed_key

            # Compute metrics
            iccoh_scores, iccoh_mean = _metrics.inner_cluster_coherence(
                ad, cluster_key=cluster_key, velocity_key=metric_key,
            )

            row = dict(params)
            row["iccoh_mean"] = iccoh_mean
            for cat, score in iccoh_scores.items():
                row[f"iccoh_{cat}"] = score

            if cluster_edges is not None:
                cbdir_scores, cbdir_mean = _metrics.cross_boundary_correctness(
                    ad, cluster_edges,
                    cluster_key=cluster_key,
                    velocity_key=metric_key,
                )
                row["cbdir_mean"] = cbdir_mean
                for (u, v), score in cbdir_scores.items():
                    row[f"cbdir_{u}_to_{v}"] = score
            else:
                row["cbdir_mean"] = float("nan")

            row["elapsed_seconds"] = elapsed
            row["status"] = "ok"

            # Update progress bar with latest metrics
            if has_bar:
                post = dict(params)
                post["ICCoh"] = f"{iccoh_mean:.3f}"
                if not np.isnan(row["cbdir_mean"]):
                    post["CBDir"] = f"{row['cbdir_mean']:.3f}"
                iterator.set_postfix(post, refresh=True)

        except Exception as e:
            # Deliberate best-effort: one failing combination must not
            # abort the whole search; the failure is recorded in the row.
            row = dict(params)
            row["iccoh_mean"] = float("nan")
            row["cbdir_mean"] = float("nan")
            row["elapsed_seconds"] = _time.time() - t0
            row["status"] = f"error: {str(e)[:80]}"

            if has_bar:
                iterator.set_postfix({"status": "FAILED"}, refresh=True)

        rows.append(row)

        if verbose and row["status"] == "ok":
            msg = f" ICCoh={row['iccoh_mean']:.3f}"
            if not np.isnan(row["cbdir_mean"]):
                msg += f" CBDir={row['cbdir_mean']:.3f}"
            msg += f" ({row['elapsed_seconds']:.1f}s)"
            print(msg)

    # Build DataFrame
    df = pd.DataFrame(rows)

    # Reorder columns: params first, then summary metrics, then details
    param_cols = param_names
    summary_cols = ["iccoh_mean", "cbdir_mean", "elapsed_seconds", "status"]
    detail_cols = [c for c in df.columns if c not in param_cols + summary_cols]
    col_order = param_cols + summary_cols + sorted(detail_cols)
    df = df[[c for c in col_order if c in df.columns]]

    def _print_best(label: str, column: str) -> None:
        # idxmax is undefined when there is no valid score (every run
        # failed, or the grid was empty); report that instead of crashing
        # and losing the collected results.
        if column not in df.columns or not df[column].notna().any():
            print(f"\nBest {label}: n/a (no successful runs)")
            return
        best = df.loc[df[column].idxmax()]
        print(f"\nBest {label} ({best[column]:.3f}):")
        for p in param_names:
            print(f" {p}: {best[p]}")

    if verbose:
        print()
        print("=" * 60)
        print("Grid Search Complete")
        print("=" * 60)
        _print_best("ICCoh", "iccoh_mean")
        if cluster_edges is not None:
            _print_best("CBDir", "cbdir_mean")

    return df
# ===================================================================== # 7. CELL TRAJECTORY TRACING # =====================================================================
def compute_trajectories(
    adata: AnnData,
    start_cells: Optional[np.ndarray] = None,
    start_cluster: Optional[str] = None,
    target_cluster: Optional[str] = None,
    start_pseudotime: Optional[float] = None,
    end_pseudotime: Optional[float] = None,
    n_trajectories: int = 20,
    n_steps: int = 200,
    step_size: float = 0.05,
    direction: str = "forward",
    velocity_key: str = "velot_velocity_pca",
    basis: str = "X_pca",
    cluster_key: str = "clusters",
    k_velocity: int = 15,
    use_network: bool = True,
    evolve_pseudotime: bool = True,
    max_attempts_factor: int = 10,
    random_state: int = 42,
) -> AnnData:
    """
    Compute cell trajectories by integrating the velocity field.

    Each trajectory is a walk over observed cells: from the current
    cell a step is taken along the (signed) velocity direction and the
    walk snaps to the nearby cell that is best aligned with that
    direction.

    Supports three modes:

    - ``"forward"``: follow the flow from starting cells.
    - ``"backward"``: go against the flow to trace origins.
    - ``"both"``: compute both directions.

    Starting cells can be selected by:

    - Explicit indices (``start_cells``).
    - Cluster name (``start_cluster``).
    - Pseudotime range (``start_pseudotime`` / ``end_pseudotime``).
    - Any combination: cluster + pseudotime range narrows the selection.

    When ``target_cluster`` is specified, only trajectories that reach
    the target are kept.  At most
    ``n_trajectories * max_attempts_factor`` integrations are attempted
    per direction (this cap applies whether or not a target is set).

    Parameters
    ----------
    adata
        Must have velocity computed, i.e. contain ``adata.obsm[basis]``,
        ``adata.obsm[velocity_key]``, ``adata.obs['pseudotime']`` and
        ``adata.obs[cluster_key]``.
    start_cells
        Array of cell indices to start from.  If provided,
        ``start_cluster`` and pseudotime range are ignored.
    start_cluster
        Cluster name to select starting cells from.  Can be combined
        with ``start_pseudotime`` / ``end_pseudotime`` to further
        narrow the selection.
    target_cluster
        If provided, only keep trajectories whose last visited cell
        belongs to this cluster (the terminal cluster when going
        forward, the origin cluster when going backward).
    start_pseudotime
        Lower bound of pseudotime for selecting starting cells.
        Defaults to ``None`` (no lower bound).  Can be used alone or
        combined with ``start_cluster``.
    end_pseudotime
        Upper bound of pseudotime for selecting starting cells.
        Defaults to ``None`` (no upper bound).  Can be used alone or
        combined with ``start_cluster``.
    n_trajectories
        Number of successful trajectories to collect per direction.
    n_steps
        Maximum integration steps.
    step_size
        Step length as a fraction of the mean nearest-neighbour
        distance in ``basis``.  The effective step is never smaller
        than half that distance, so that a step can reach another cell.
    direction
        ``"forward"``, ``"backward"``, or ``"both"``.
    velocity_key
        Key in adata.obsm with velocity vectors.
    basis
        Embedding key for cell coordinates.
    cluster_key
        Cluster annotation column.
    k_velocity
        KNN for velocity interpolation (fallback mode, used when the
        smoothing network is unavailable or ``use_network`` is False).
    use_network
        If True and the smoothing network is available, query the
        continuous velocity field directly.
    evolve_pseudotime
        If True, a trajectory stops once the pseudotime of the current
        cell leaves the interval ``[-0.1, 1.1]``.
    max_attempts_factor
        Try up to ``n_trajectories * max_attempts_factor`` starting
        cells per direction to find enough valid trajectories.
    random_state
        Random seed.

    Returns
    -------
    adata with ``adata.uns['velot_trajectories']``.

    The stored metadata for each trajectory includes an ``"id"``
    field that can be used with ``velot.pl.trajectories(trajectory_ids=...)``
    to plot specific trajectories.

    Raises
    ------
    ValueError
        If ``direction`` is invalid, ``cluster_key`` is missing, no
        selection criterion is given, or the selection is empty.
    ImportError
        If the smoothing network is to be used but PyTorch is missing.

    Notes
    -----
    Candidates are visited in a shuffled order and the order wraps
    around when it is exhausted, so with fewer candidates than
    ``n_trajectories`` the same start cell can be collected repeatedly.
    Trajectories shorter than 5 points are discarded.

    Examples
    --------
    All backward trajectories from Alpha::

        velot.tl.compute_trajectories(
            adata, start_cluster="Alpha", direction="backward",
        )

    Only backward trajectories from Alpha that reach Ngn3 low EP::

        velot.tl.compute_trajectories(
            adata, start_cluster="Alpha", direction="backward",
            target_cluster="Ngn3 low EP", n_trajectories=5,
        )

    Forward from progenitors that reach Epsilon::

        velot.tl.compute_trajectories(
            adata, start_cluster="Ngn3 low EP", direction="forward",
            target_cluster="Epsilon", n_trajectories=10,
        )

    Forward from early cells (pseudotime 0 to 0.1) regardless of cluster::

        velot.tl.compute_trajectories(
            adata, start_pseudotime=0.0, end_pseudotime=0.1,
            direction="forward",
        )

    Backward from Alpha cells with pseudotime > 0.8::

        velot.tl.compute_trajectories(
            adata, start_cluster="Alpha", start_pseudotime=0.8,
            direction="backward",
        )
    """
    _check_fields(
        adata, obsm_keys=[basis, velocity_key], obs_keys=["pseudotime"]
    )

    if direction not in ("forward", "backward", "both"):
        raise ValueError(
            f"direction must be 'forward', 'backward', or 'both', "
            f"got '{direction}'."
        )

    # Cluster labels are recorded for every visited cell, so the column
    # is required whichever way the start cells are selected.
    if cluster_key not in adata.obs:
        raise ValueError(f"'{cluster_key}' not found in adata.obs.")

    X = adata.obsm[basis]
    V = adata.obsm[velocity_key]
    pseudotime = adata.obs["pseudotime"].values
    cluster_labels = np.asarray(adata.obs[cluster_key].astype(str))
    n_cells = X.shape[0]

    tree = cKDTree(X)

    # Check if continuous field is available
    has_network = (
        use_network
        and "velot_smoothing" in adata.uns
        and "network" in adata.uns["velot_smoothing"]
    )

    if has_network:
        # PyTorch is only needed to evaluate the network; the KNN
        # fallback below works without it.
        if not _HAS_TORCH:
            raise ImportError(
                "PyTorch is required to query the smoothing network. "
                "Install with: pip3 install torch"
            )
        print(f"Found torch compatible version. Running on {DEVICE} device")
        net = adata.uns["velot_smoothing"]["network"]
        use_pt = adata.uns["velot_smoothing"]["use_pseudotime"]
        net.eval()

    # ------------------------------------------------------------------
    # Select candidate starting cells
    # Priority: start_cells > (start_cluster and/or pseudotime range)
    # ------------------------------------------------------------------
    has_pt_range = (
        start_pseudotime is not None or end_pseudotime is not None
    )

    if start_cells is not None:
        # Explicit indices — use as-is, no further filtering
        all_candidates = np.atleast_1d(np.asarray(start_cells))
        if all_candidates.size == 0:
            raise ValueError("start_cells is empty.")

    elif start_cluster is not None or has_pt_range:
        if start_cluster is not None:
            # Cluster filter
            all_candidates = np.where(
                cluster_labels == str(start_cluster)
            )[0]
            if len(all_candidates) == 0:
                available = sorted(set(cluster_labels.tolist()))
                raise ValueError(
                    f"Cluster '{start_cluster}' not found. "
                    f"Available: {available}"
                )
        else:
            # No cluster filter — all cells are candidates
            all_candidates = np.arange(n_cells)

        # Apply pseudotime range filter on top
        if has_pt_range:
            pt_min = (
                start_pseudotime if start_pseudotime is not None else -np.inf
            )
            pt_max = end_pseudotime if end_pseudotime is not None else np.inf

            pt_vals = pseudotime[all_candidates]
            pt_mask = (pt_vals >= pt_min) & (pt_vals <= pt_max)
            all_candidates = all_candidates[pt_mask]

            if len(all_candidates) == 0:
                # Helpful error message
                if start_cluster is not None:
                    cluster_pts = pseudotime[
                        cluster_labels == str(start_cluster)
                    ]
                    raise ValueError(
                        f"No cells in cluster '{start_cluster}' with "
                        f"pseudotime in [{pt_min}, {pt_max}]. "
                        f"Cluster pseudotime range: "
                        f"[{cluster_pts.min():.3f}, {cluster_pts.max():.3f}]"
                    )
                raise ValueError(
                    f"No cells with pseudotime in [{pt_min}, {pt_max}]. "
                    f"Data pseudotime range: "
                    f"[{pseudotime.min():.3f}, {pseudotime.max():.3f}]"
                )
    else:
        raise ValueError(
            "Provide at least one of: start_cells, start_cluster, "
            "or a pseudotime range (start_pseudotime / end_pseudotime)."
        )

    # Report selection
    selection_desc = []
    if start_cluster is not None:
        selection_desc.append(f"cluster='{start_cluster}'")
    if has_pt_range:
        pt_min_str = (
            f"{start_pseudotime:.3f}" if start_pseudotime is not None else "min"
        )
        pt_max_str = (
            f"{end_pseudotime:.3f}" if end_pseudotime is not None else "max"
        )
        selection_desc.append(f"pseudotime=[{pt_min_str}, {pt_max_str}]")
    if selection_desc:
        print(f" Starting cells: {len(all_candidates)} candidates "
              f"({', '.join(selection_desc)})")

    rng = np.random.RandomState(random_state)

    # Step size: proportional to mean inter-cell distance.  The sample
    # is drawn from its own seeded generator so that the calibration is
    # reproducible without disturbing the start-cell shuffle of `rng`.
    calib_rng = np.random.RandomState(random_state)
    sample_idx = calib_rng.choice(n_cells, min(500, n_cells), replace=False)
    mean_nn_dist = np.mean(tree.query(X[sample_idx], k=2)[0][:, 1])
    dt = step_size * mean_nn_dist
    # A step shorter than half the typical spacing cannot leave a cell
    snap_dt = max(dt, mean_nn_dist * 0.5)

    # UMAP availability
    has_umap = "X_umap" in adata.obsm
    if has_umap:
        X_umap = adata.obsm["X_umap"]

    def _knn(point, k):
        """Query the k nearest cells, always returning 1-D arrays."""
        # k > n_cells yields padded (inf / out-of-range) neighbours and
        # k == 1 yields scalars; normalise both cases.
        k = max(1, min(k, n_cells))
        dists, indices = tree.query(point, k=k)
        return np.atleast_1d(dists), np.atleast_1d(indices)

    def _velocity_at(pos, pt_value):
        """Velocity at ``pos``: network if available, else KNN average."""
        if has_network:
            pos_t = torch.tensor(
                pos.reshape(1, -1).astype(np.float32), device=DEVICE
            )
            pt_t = torch.tensor(
                [[pt_value]], dtype=torch.float32, device=DEVICE
            )
            with torch.no_grad():
                v = net(pos_t, pt_t if use_pt else None)
            return v.cpu().numpy().flatten()
        # Inverse-distance weighted average of neighbouring velocities
        dists, indices = _knn(pos, k_velocity)
        weights = 1.0 / (dists + 1e-10)
        weights = weights / weights.sum()
        return np.sum(V[indices] * weights[:, None], axis=0)

    def _integrate(cell_idx, sign=1.0):
        """Walk from ``cell_idx`` over real cells; return visited indices."""
        current_cell = int(cell_idx)
        visited = [current_cell]
        k_snap = min(5, n_cells - 1)

        for _ in range(n_steps):
            pos = X[current_cell]
            v = _velocity_at(pos, float(pseudotime[current_cell]))
            v_mag = np.linalg.norm(v)
            if v_mag < 1e-8:
                break

            v_unit = v / v_mag
            dists_snap, idx_snap = _knn(pos + sign * snap_dt * v_unit, k_snap)

            # Pick the neighbour of the probe point that lies most
            # along the travel direction, favouring close ones.
            best_cell = None
            best_score = -np.inf
            for dist, candidate_cell in zip(dists_snap, idx_snap):
                displacement = X[candidate_cell] - pos
                disp_mag = np.linalg.norm(displacement)
                if disp_mag < 1e-10:
                    continue
                alignment = np.dot(sign * v_unit, displacement / disp_mag)
                score = alignment * (1.0 / (dist + 1e-10))
                if score > best_score:
                    best_score = score
                    best_cell = candidate_cell

            if best_cell is None or best_cell == current_cell:
                # Probe twice as far and take any other cell
                _, idx2 = _knn(pos + sign * snap_dt * 2 * v_unit, k_snap)
                for candidate_cell in idx2:
                    if candidate_cell != current_cell:
                        best_cell = candidate_cell
                        break

            if best_cell is None or best_cell == current_cell:
                # The step is deterministic, so retrying from the same
                # cell would stall again: stop here.
                break

            current_cell = int(best_cell)
            visited.append(current_cell)

            # Boundary check
            if evolve_pseudotime:
                pt_now = float(pseudotime[current_cell])
                if pt_now < -0.1 or pt_now > 1.1:
                    break

        return visited

    def _matches_target(clusters_visited):
        """Check if this trajectory reached the target cluster."""
        # The last visited cell is the terminal cluster going forward
        # and the origin cluster going backward; labels are compared
        # as strings, like ``start_cluster``.
        if target_cluster is None:
            return True
        return clusters_visited[-1] == str(target_cluster)

    # ------------------------------------------------------------------
    # Integrate trajectories with target filtering
    # ------------------------------------------------------------------
    trajectories_pca = []
    trajectories_umap = []
    trajectory_metadata = []

    directions_to_run = {
        "forward": [("forward", 1.0)],
        "backward": [("backward", -1.0)],
        "both": [("forward", 1.0), ("backward", -1.0)],
    }[direction]

    for dir_name, sign in directions_to_run:
        collected = 0
        attempts = 0
        max_attempts = n_trajectories * max_attempts_factor

        shuffled = rng.permutation(all_candidates)

        while collected < n_trajectories and attempts < max_attempts:
            # Wrap around when every candidate has been tried once
            cell_idx = int(shuffled[attempts % len(shuffled)])
            attempts += 1

            cells = _integrate(cell_idx, sign)
            clusters_visited = [str(cluster_labels[c]) for c in cells]

            if not _matches_target(clusters_visited):
                continue

            if len(cells) < 5:
                continue

            traj_id = len(trajectory_metadata)
            path_pca = np.array(X[cells])
            trajectories_pca.append(path_pca)

            if has_umap:
                # Every path point is an observed cell, so its UMAP
                # coordinates can be looked up directly.
                trajectories_umap.append(np.array(X_umap[cells]))

            if sign > 0:
                origin_cluster = clusters_visited[0]
                terminal_cluster = clusters_visited[-1]
            else:
                origin_cluster = clusters_visited[-1]
                terminal_cluster = clusters_visited[0]

            trajectory_metadata.append({
                "id": traj_id,
                "start_cell": cell_idx,
                "start_cluster": str(cluster_labels[cell_idx]),
                "direction": dir_name,
                "n_steps": len(path_pca),
                "clusters_visited": clusters_visited,
                "pseudotime_along": [float(pseudotime[c]) for c in cells],
                "origin_cluster": origin_cluster,
                "terminal_cluster": terminal_cluster,
            })

            collected += 1

        if target_cluster is not None:
            print(
                f" {dir_name.capitalize()}: found "
                f"{collected}/{n_trajectories} "
                f"trajectories reaching '{target_cluster}' "
                f"({attempts} attempts)"
            )

    adata.uns["velot_trajectories"] = {
        "paths_pca": trajectories_pca,
        "paths_umap": trajectories_umap if has_umap else [],
        "metadata": trajectory_metadata,
        "start_cluster": start_cluster,
        "target_cluster": target_cluster,
        "start_pseudotime": start_pseudotime,
        "end_pseudotime": end_pseudotime,
        "direction": direction,
        "n_trajectories": len(trajectory_metadata),
        "n_steps": n_steps,
        "step_size": step_size,
        "used_network": has_network,
        "evolved_pseudotime": evolve_pseudotime,
    }

    # Summary
    print(
        f" Total: {len(trajectory_metadata)} trajectories "
        f"({direction}) from '{start_cluster}'"
        + (f" to '{target_cluster}'" if target_cluster else "")
    )
    if has_network:
        print(f" Using continuous velocity field (neural network)")

    for dir_name, _ in directions_to_run:
        dir_meta = [
            m for m in trajectory_metadata if m["direction"] == dir_name
        ]
        if not dir_meta:
            continue

        if dir_name == "forward":
            endpoints = [m["terminal_cluster"] for m in dir_meta]
            label = "Terminal"
        else:
            endpoints = [m["origin_cluster"] for m in dir_meta]
            label = "Origin"

        unique, counts = np.unique(endpoints, return_counts=True)
        summary = dict(zip(unique, counts))
        print(f" {dir_name.capitalize()} {label} clusters: {summary}")

    return adata
def simulate_flow(
    adata: AnnData,
    n_particles: int = 200,
    source_cluster: Optional[str] = None,
    source_pseudotime_min: float = 0.0,
    source_pseudotime_max: float = 0.1,
    source_x_lim: tuple = (-np.inf, np.inf),
    source_y_lim: tuple = (-np.inf, np.inf),
    n_steps: int = 300,
    step_size: float = 0.05,
    diffusion: float = 0.5,
    basis: str = "X_pca",
    cluster_key: str = "clusters",
    noise_scale: float = 0.0,
    random_state: int = 42,
) -> AnnData:
    """
    Simulate a flow of particles through the learned velocity field.

    Drops particles at the "top of the river" (low pseudotime or
    a source cluster) and integrates them forward through the
    velocity field with optional stochastic diffusion.

    Each step moves a particle a fixed distance along the local
    velocity *direction* and adds Gaussian noise:

        x_{n+1} = x_n + dt * v̂(x_n, τ_n) + s * η_n

    where ``v̂`` is the unit velocity returned by the smoothing network,
    ``dt = step_size * d`` with ``d`` the mean nearest-neighbour
    distance in ``basis``, ``s = diffusion * d * sqrt(dt)`` and ``η`` is
    standard Gaussian noise re-weighted to 20% of its component along
    the flow and 80% of its component perpendicular to it.

    The diffusion allows particles starting from the same region
    to explore different branches at bifurcation points, producing
    realistic fate distributions.

    Parameters
    ----------
    adata
        Must have the smoothing network trained
        (``adata.uns['velot_smoothing']['network']``) and contain
        ``adata.obsm[basis]``, ``adata.obs['pseudotime']`` and
        ``adata.obs[cluster_key]``.
    n_particles
        Number of particles to simulate.  Fewer are seeded when there
        are fewer candidate cells.
    source_cluster
        Cluster to seed particles from.  Takes precedence over the
        pseudotime and coordinate selections.
    source_pseudotime_min, source_pseudotime_max
        If ``source_cluster`` is None and both bounds are not None,
        seed from cells whose pseudotime lies in this closed interval.
    source_x_lim, source_y_lim
        ``(low, high)`` bounds on the first and second component of
        ``basis``.  Only used when ``source_cluster`` is None and one
        of the pseudotime bounds is None.
    n_steps
        Maximum number of integration steps.
    step_size
        Step length as a fraction of the mean nearest-neighbour
        distance in ``basis``.
    diffusion
        Diffusion coefficient controlling stochasticity.
        0.0 = deterministic (all trajectories converge).
        0.1-0.5 = mild stochasticity (some branching exploration).
        1.0+ = strong stochasticity (wide exploration).
        The noise amplitude is the same at every step; it does not
        depend on the local velocity magnitude.
    basis
        PCA embedding key.
    cluster_key
        Cluster annotation column.
    noise_scale
        Gaussian noise added to INITIAL positions only.
        Separate from diffusion which is added at every step.
    random_state
        Random seed.

    Returns
    -------
    adata with ``adata.uns['velot_flow']``.

    Raises
    ------
    ValueError
        If the smoothing network is missing or ``source_cluster`` has
        no cells.

    Notes
    -----
    When the pseudotime or coordinate selection matches no cell, the
    ``max(10, n_particles)`` cells with lowest pseudotime are used as
    candidates instead.  A particle stops when the velocity vanishes,
    when it is farther than 10 mean nearest-neighbour distances from
    every cell, or when its estimated pseudotime exceeds 0.99.

    Examples
    --------
    Deterministic flow (trajectories will converge)::

        velot.tl.simulate_flow(adata, diffusion=0.0, ...)

    Stochastic flow (trajectories explore branches)::

        velot.tl.simulate_flow(adata, diffusion=0.3, ...)

    Strong diffusion (wide exploration)::

        velot.tl.simulate_flow(adata, diffusion=1.0, ...)
    """
    _check_fields(
        adata, obsm_keys=[basis], obs_keys=["pseudotime", cluster_key]
    )

    if (
        "velot_smoothing" not in adata.uns
        or "network" not in adata.uns["velot_smoothing"]
    ):
        raise ValueError(
            "Smoothing network not found. Run velot.tl.velocity(smooth=True) first."
        )

    X = adata.obsm[basis]
    pseudotime = adata.obs["pseudotime"].values
    n_cells, dim = X.shape

    net = adata.uns["velot_smoothing"]["network"]
    use_pt = adata.uns["velot_smoothing"]["use_pseudotime"]
    net.eval()

    tree = cKDTree(X)
    rng = np.random.RandomState(random_state)

    # UMAP
    has_umap = "X_umap" in adata.obsm
    if has_umap:
        X_umap = adata.obsm["X_umap"]

    # Step size calibration from the typical nearest-neighbour spacing.
    # The sample comes from its own seeded generator so that results
    # are reproducible without disturbing the seeding draws of `rng`.
    calib_rng = np.random.RandomState(random_state)
    sample_idx = calib_rng.choice(n_cells, min(500, n_cells), replace=False)
    mean_nn_dist = np.mean(tree.query(X[sample_idx], k=2)[0][:, 1])
    dt = step_size * mean_nn_dist
    diffusion_scale = diffusion * mean_nn_dist * np.sqrt(dt)
    # Manifold boundary: generous threshold
    boundary_threshold = 10 * mean_nn_dist

    # ------------------------------------------------------------------
    # Seed particles
    # ------------------------------------------------------------------
    if source_cluster is not None:
        mask = adata.obs[cluster_key].astype(str) == str(source_cluster)
        candidates = np.where(mask)[0]
        if len(candidates) == 0:
            raise ValueError(f"Cluster '{source_cluster}' not found.")
        selected = rng.choice(
            candidates,
            size=min(n_particles, len(candidates)),
            replace=(n_particles > len(candidates)),
        )
        source_label = source_cluster
    elif source_pseudotime_min is not None and source_pseudotime_max is not None:
        mask = (
            (source_pseudotime_min <= pseudotime)
            & (pseudotime <= source_pseudotime_max)
        )
        candidates = np.where(mask)[0]
        if len(candidates) == 0:
            candidates = np.argsort(pseudotime)[:max(10, n_particles)]
        selected = rng.choice(
            candidates,
            size=min(n_particles, len(candidates)),
            replace=True,
        )
        source_label = (
            f"{source_pseudotime_min} ≤ pseudotime ≤ {source_pseudotime_max}"
        )
    else:
        mask = (
            (X[:, 0] >= source_x_lim[0]) & (X[:, 0] <= source_x_lim[1])
            & (X[:, 1] >= source_y_lim[0]) & (X[:, 1] <= source_y_lim[1])
        )
        candidates = np.where(mask)[0]
        if len(candidates) == 0:
            candidates = np.argsort(pseudotime)[:max(10, n_particles)]
        selected = rng.choice(
            candidates,
            size=min(n_particles, len(candidates)),
            replace=True,
        )
        source_label = (
            f"x lim $\\in$ {source_x_lim}; y lim $\\in$ {source_y_lim}"
        )

    # Initial positions
    initial_positions = X[selected].copy()
    if noise_scale > 0:
        initial_positions += rng.normal(0, noise_scale, initial_positions.shape)

    initial_pseudotimes = pseudotime[selected].copy()

    print(f" Seeding {len(selected)} particles from {source_label}")
    print(f" Diffusion coefficient: {diffusion} "
          f"(scale={diffusion_scale:.4f} per step)")

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _velocity_at(pos, pt_value):
        """Evaluate the smoothing network at one position."""
        pos_t = torch.tensor(
            pos.reshape(1, -1).astype(np.float32), device=DEVICE
        )
        pt_t = torch.tensor(
            [[pt_value]], dtype=torch.float32, device=DEVICE
        )
        with torch.no_grad():
            v = net(pos_t, pt_t if use_pt else None)
        return v.cpu().numpy().flatten()

    def _estimate_pseudotime_at(pos):
        """Inverse-distance weighted pseudotime of the nearest cells."""
        dists, indices = tree.query(pos, k=min(10, n_cells - 1))
        weights = 1.0 / (dists + 1e-10)
        weights = weights / weights.sum()
        return float(np.sum(pseudotime[indices] * weights))

    def _nearest_cluster(pos):
        """Cluster label of the cell closest to ``pos``."""
        _, idx = tree.query(pos, k=1)
        return str(adata.obs[cluster_key].iloc[idx])

    def _project_to_umap(pos):
        """Map ``pos`` to UMAP via a local least-squares linear fit."""
        dists, indices = tree.query(pos, k=min(15, n_cells - 1))
        dX = X[indices] - pos
        dU = X_umap[indices] - X_umap[indices[0]]
        try:
            A, _, _, _ = np.linalg.lstsq(dX, dU, rcond=None)
            return X_umap[indices[0]] + (pos - X[indices[0]]) @ A
        except np.linalg.LinAlgError:
            # Fall back to the nearest cell's UMAP coordinates
            return X_umap[indices[0]]

    # ------------------------------------------------------------------
    # Integrate all particles (SDE)
    # ------------------------------------------------------------------
    all_paths = []
    all_paths_umap = []
    all_pseudotimes = []
    all_clusters = []

    for p_idx in range(len(selected)):
        pos = initial_positions[p_idx].copy()
        pt_current = float(initial_pseudotimes[p_idx])

        path = [pos.copy()]
        pt_along = [pt_current]
        cl_along = [_nearest_cluster(pos)]

        for _ in range(n_steps):
            # Deterministic drift
            v = _velocity_at(pos, pt_current)
            v_mag = np.linalg.norm(v)

            if v_mag < 1e-8:
                break

            v_unit = v / (v_mag + 1e-10)

            # Stochastic diffusion
            if diffusion > 0:
                noise = rng.normal(0, 1, dim)

                # Split the noise into its components along and
                # across the flow direction
                parallel = np.dot(noise, v_unit) * v_unit
                perpendicular = noise - parallel

                # Keep 80% perpendicular, 20% parallel
                # This lets particles spread across branches
                # without fighting the main flow direction
                noise = 0.2 * parallel + 0.8 * perpendicular
                noise = noise * diffusion_scale
            else:
                noise = 0.0

            # SDE step: unit-speed drift + diffusion
            pos = pos + dt * v_unit + noise
            path.append(pos.copy())

            # Update pseudotime
            pt_current = _estimate_pseudotime_at(pos)
            pt_along.append(pt_current)
            cl_along.append(_nearest_cluster(pos))

            # Stop conditions
            dist_to_nearest, _ = tree.query(pos, k=1)
            if dist_to_nearest > boundary_threshold:
                break
            if pt_current > 0.99:
                break

        path = np.array(path)
        all_paths.append(path)
        all_pseudotimes.append(pt_along)
        all_clusters.append(cl_along)

        if has_umap:
            path_umap = np.array([_project_to_umap(p) for p in path])
            all_paths_umap.append(path_umap)

    # ------------------------------------------------------------------
    # Summary statistics
    # ------------------------------------------------------------------
    terminal_clusters = [cl[-1] for cl in all_clusters]
    unique_terminals, terminal_counts = np.unique(
        terminal_clusters, return_counts=True
    )
    terminal_fractions = dict(zip(
        unique_terminals,
        (terminal_counts / terminal_counts.sum()).round(3),
    ))
    counts_by_cluster = dict(zip(unique_terminals, terminal_counts))

    final_pseudotimes = [pt[-1] for pt in all_pseudotimes]
    mean_steps = np.mean([len(p) for p in all_paths])

    adata.uns["velot_flow"] = {
        "particles": all_paths,
        "particles_umap": all_paths_umap if has_umap else [],
        "pseudotime_along": all_pseudotimes,
        "cluster_along": all_clusters,
        "initial_cells": selected,
        "source": source_label,
        "n_particles": len(selected),
        "diffusion": diffusion,
        "metadata": {
            "terminal_fractions": terminal_fractions,
            "mean_final_pseudotime": float(np.mean(final_pseudotimes)),
            "mean_trajectory_length": float(mean_steps),
        },
    }

    print(f" Flow simulation complete:")
    print(f" Mean trajectory length: {mean_steps:.0f} steps")
    print(f" Mean final pseudotime: {np.mean(final_pseudotimes):.3f}")
    print(f" Terminal fate distribution:")
    for cl, frac in sorted(terminal_fractions.items(), key=lambda x: -x[1]):
        count = int(counts_by_cluster[cl])
        print(f" {cl}: {frac:.1%} ({count} particles)")

    return adata