# Source code for velot.pl

"""
Visualization functions for the VelOT pipeline.
Follows the scanpy/scvelo convention.

The plotting entry points accept ``show=True`` (display immediately),
``save=False`` (write the figure to disk) and ``save_path=None``
(destination path used when ``save=True``).

Usage::

    import velot

    velot.pl.velocity_stream(adata, color="clusters")
    velot.pl.windows(adata)
    velot.pl.training_curves(adata)
"""

from __future__ import annotations

from pathlib import Path
from typing import Optional, Sequence

import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc
from anndata import AnnData

try:
    import scvelo as scv

    _HAS_SCVELO = True
except ImportError:
    _HAS_SCVELO = False


# =====================================================================
# Main plotting functions
# =====================================================================

def dataset_overview_simple(
    adata: AnnData,
    color: str = "clusters",
    basis: str = "umap",
    title: Optional[str] = "",
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (5, 5),
    ax: Optional[plt.Axes] = None,
    inframe: bool = False,
) -> plt.Figure:
    """
    Single-panel overview: one embedding colored by ``color``.

    Parameters
    ----------
    adata
        Annotated data matrix; the embedding named by ``basis`` is plotted.
    color
        Key passed as ``color`` to the underlying scatter call.
    basis
        Embedding to plot (e.g. ``'umap'`` or ``'pca'``). Also used to
        label the axis arrows.
    title
        Axes title.
    show
        Whether to display the figure. Only used when ``ax`` is None.
    save
        Whether to write the figure to disk. Only used when ``ax`` is None.
    save_path
        Destination path used when ``save=True``.
    figsize
        Figure size, used when a new figure is created.
    ax
        Existing axes to draw on. If None, a new figure is created and
        finalized (shown/saved/closed) by this function.
    inframe
        If True, draw with ``scv.pl.scatter`` and ``legend_loc="on data"``;
        otherwise draw with ``sc.pl.embedding``.

    Returns
    -------
    None when ``show`` is True. Otherwise the new Figure, or the supplied
    ``ax`` when one was passed in.

    Raises
    ------
    ImportError
        If ``inframe=True`` and scVelo is not installed.
    """
    # Fail before creating a figure so nothing is leaked on error.
    if inframe and not _HAS_SCVELO:
        raise ImportError(
            "dataset_overview_simple(inframe=True) requires scVelo. "
            "Install with: pip install scvelo"
        )

    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=figsize)

    if inframe:
        scv.pl.scatter(
            adata,
            basis=basis,
            color=color,
            ax=ax,
            show=False,
            title="",
            frameon=False,
            legend_loc="on data",
        )
    else:
        sc.pl.embedding(
            adata,
            basis=basis,
            color=color,
            ax=ax,
            show=False,
            title="",
            frameon=False,
        )

    # If the scatter call attached a legend, replace it with a restyled one
    # placed to the right of the plot.
    leg = ax.get_legend()
    if leg is not None:
        handles = leg.legend_handles
        labels = [t.get_text() for t in leg.get_texts()]
        leg.remove()

        legend_style = dict(
            loc="center left",
            bbox_to_anchor=(1.02, 0.5),
            fontsize=14,
            markerscale=1.5,
            frameon=False,
        )
        # A figure-level legend is only used when this function owns the
        # figure; otherwise stay within the caller's axes.
        legend_owner = fig if owns_figure else ax
        legend_owner.legend(handles, labels, **legend_style)

    ax.set_title(title, fontsize=18, fontfamily="sans serif")
    # Pass the basis through so the arrows are labelled after the embedding
    # actually plotted (previously always "UMAP1"/"UMAP2").
    add_umap_axis(ax, basis=basis)

    if owns_figure:
        _finish(fig, show=show, save=save, save_path=save_path)
        return fig if not show else None
    return ax if not show else None
def dataset_overview(
    adata: AnnData,
    color: str = "clusters",
    basis: str = "umap",
    title: Optional[str] = "",
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (5, 5),
    inframe: bool = False,
    out_legend: bool = False,
    vertical: bool = False,
) -> plt.Figure:
    """
    Two-panel overview: embedding colored by ``color`` and by pseudotime.

    Parameters
    ----------
    adata
        Annotated data matrix. If ``adata.obs`` has no ``"pseudotime"``
        column, the second panel shows a placeholder message instead.
    color
        Key passed as ``color`` to ``sc.pl.embedding`` for the first panel.
    basis
        Embedding to plot (e.g. ``'umap'`` or ``'pca'``). Also used to
        label the axis arrows.
    title
        Acts as a switch: if truthy, the panels are titled with ``color``
        and ``"Diffusion Pseudotime"``; otherwise both titles are empty.
        The string itself is not displayed.
    show
        Whether to display the figure.
    save
        Whether to write the figure to disk.
    save_path
        Destination path used when ``save=True``.
    figsize
        Size of ONE panel; the figure is scaled by the grid shape.
    inframe
        Currently unused by this function.
    out_legend
        Currently unused by this function.
    vertical
        Stack the two panels vertically (2x1) instead of side by side (1x2).

    Returns
    -------
    The Figure if ``show`` is False, otherwise None.
    """
    nrows, ncols = (2, 1) if vertical else (1, 2)
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(ncols * figsize[0], nrows * figsize[1])
    )

    titles = (color, "Diffusion Pseudotime") if title else ("", "")

    # ------------------------------------------------------------------
    # Panel 1: embedding colored by `color`
    # ------------------------------------------------------------------
    ax = axes[0]
    sc.pl.embedding(
        adata,
        basis=basis,
        color=color,
        ax=ax,
        show=False,
        title="",
        frameon=False,
        legend_loc="on data",
    )
    ax.set_title(titles[0], fontsize=18, fontfamily="sans serif")
    # Label the arrows after the plotted basis (previously always UMAP).
    add_umap_axis(ax, basis=basis)

    # If a legend object exists, swap it for a restyled one.
    leg = ax.get_legend()
    if leg is not None:
        handles = leg.legend_handles
        labels = [t.get_text() for t in leg.get_texts()]
        leg.remove()
        ax.legend(
            handles,
            labels,
            loc="lower right",
            fontsize=14,
            markerscale=1.5,
            frameon=False,
        )

    # ------------------------------------------------------------------
    # Panel 2: pseudotime (or a placeholder when it is missing)
    # ------------------------------------------------------------------
    if "pseudotime" in adata.obs:
        sc.pl.embedding(
            adata,
            basis=basis,
            color="pseudotime",
            ax=axes[1],
            show=False,
            title="",
            frameon=False,
            color_map="viridis",
        )
    else:
        axes[1].text(
            0.5,
            0.5,
            "No pseudotime computed",
            ha="center",
            va="center",
            transform=axes[1].transAxes,
        )
    # The title is set once here for both branches (a redundant earlier
    # set_title in the placeholder branch was always overwritten).
    axes[1].set_title(titles[1], fontsize=18, fontfamily="sans serif")
    add_umap_axis(axes[1], basis=basis)

    plt.tight_layout()
    _finish(fig, show=show, save=save, save_path=save_path)
    return fig if not show else None
def add_umap_axis(ax, basis='umap', pos=(0, 0), length=0.2, fontsize=10):
    """
    Draw a pair of small labelled axis arrows in the corner of a plot.

    The vertical arrow is rescaled by the axes' width/height ratio so that
    both arrows span the same physical length, as measured at call time.

    Parameters
    ----------
    ax
        Matplotlib Axes to annotate.
    basis
        Embedding name; the labels are its upper-cased form followed by
        ``1`` and ``2`` (e.g. ``UMAP1``/``UMAP2``, ``PCA1``/``PCA2``).
    pos
        Arrow origin in axes-fraction coordinates.
    length
        Length of the horizontal arrow in axes-fraction units; the vertical
        arrow is ``length`` times the width/height ratio.
    fontsize
        Font size of the two labels.
    """
    origin_x, origin_y = pos
    figure = ax.figure

    # Measure the axes: prefer the rendered extent (pixels); if the canvas
    # cannot provide a renderer, fall back to figure size x axes position
    # (inches). Only the width/height ratio is used, so units do not matter.
    try:
        extent = ax.get_window_extent(renderer=figure.canvas.get_renderer())
        width, height = extent.width, extent.height
    except (AttributeError, TypeError):
        placement = ax.get_position()
        fig_w, fig_h = figure.get_size_inches()
        width, height = placement.width * fig_w, placement.height * fig_h

    # Equal physical length: span_y * height == span_x * width.
    span_x = length
    span_y = length * (width / max(height, 1e-10))

    # Horizontal arrow first, then vertical, both starting at the origin.
    arrow_tips = ((origin_x + span_x, origin_y), (origin_x, origin_y + span_y))
    for tip in arrow_tips:
        ax.annotate(
            "",
            xy=tip,
            xytext=(origin_x, origin_y),
            xycoords=ax.transAxes,
            arrowprops=dict(arrowstyle="->", lw=1.5, color="black"),
        )

    # Labels: "<BASIS>1" under the x-arrow tip, "<BASIS>2" left of the
    # y-arrow tip (rotated).
    prefix = basis.upper()
    shared = dict(transform=ax.transAxes, ha="right", va="top", fontsize=fontsize)
    ax.text(origin_x + span_x, origin_y - 0.02, f"{prefix}1", **shared)
    ax.text(origin_x - 0.02, origin_y + span_y, f"{prefix}2", rotation=90, **shared)
def windows(
    adata: AnnData,
    basis: str = "umap",
    pairs_to_show: Optional[Sequence[int]] = None,
    max_show: int = 12,
    ncols: int = 4,
    point_size: int = 20,
    title: str = None,
    pair_title: bool = False,
    frameon: bool = False,
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize_per_panel: tuple = (3, 3),
) -> plt.Figure:
    """
    Visualize the window pairs stored in ``adata.uns['velot_windows']``.

    Each panel shows one window pair on top of all cells (light gray):
    source-only cells in blue, target-only cells in orange, and cells that
    belong to both in green.

    Parameters
    ----------
    adata
        Must have ``adata.uns['velot_windows']['pairs']`` (the error message
        points to ``velot.tl.build_windows()``) and ``adata.obsm['X_<basis>']``.
    basis
        Embedding to plot in; also used to label the axis arrows.
    pairs_to_show
        Specific pair indices to display, e.g. ``[10, 13, 14, 15]``.
        If None, the first ``max_show`` pairs are shown.
    max_show
        Maximum number of pairs when ``pairs_to_show`` is None.
    ncols
        Number of columns in the panel grid.
    point_size
        Scatter point size for the highlighted cells.
    title
        Acts as a switch: any non-None value adds a fixed
        ``"Window pairs (k of n shown)"`` suptitle. The string itself is
        not displayed.
    pair_title
        If True, title each panel with ``"Pair <index>"``.
    frameon
        If False the axes are turned off; if True only the ticks are removed.
    show
        Whether to display the figure.
    save
        Whether to write the figure to disk.
    save_path
        Destination path used when ``save=True``.
    figsize_per_panel
        Size of one panel; the figure is scaled by the grid shape.

    Returns
    -------
    The Figure if ``show`` is False, otherwise None.

    Raises
    ------
    ValueError
        If no windows are stored, a pair index is out of range, there is
        nothing to show, or the embedding is missing.

    Examples
    --------
    Show first 8 pairs::

        velot.pl.windows(adata, max_show=8)

    Show specific pairs::

        velot.pl.windows(adata, pairs_to_show=[10, 13, 14, 15])
    """
    if "velot_windows" not in adata.uns:
        raise ValueError("No windows found. Run velot.tl.build_windows() first.")

    pairs = adata.uns["velot_windows"]["pairs"]
    n_total = len(pairs)

    # Determine which pairs to show
    if pairs_to_show is not None:
        indices = list(pairs_to_show)
        for idx in indices:
            if idx < 0 or idx >= n_total:
                raise ValueError(
                    f"Pair index {idx} out of range. "
                    f"Dataset has {n_total} pairs (0 to {n_total - 1})."
                )
    else:
        indices = list(range(min(max_show, n_total)))

    # An empty selection would otherwise fail inside plt.subplots (0 rows)
    # with an unrelated message; report it explicitly instead.
    if not indices:
        raise ValueError("No window pairs to show.")

    n_show = len(indices)
    nrows = int(np.ceil(n_show / ncols))

    coords = adata.obsm.get(f"X_{basis}")
    if coords is None:
        raise ValueError(f"Embedding 'X_{basis}' not found in adata.obsm.")

    figsize = (ncols * figsize_per_panel[0], nrows * figsize_per_panel[1])
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    axes = np.atleast_2d(axes).flatten()

    for ax, pair_idx in zip(axes, indices):
        idx_src, idx_tgt = pairs[pair_idx]

        # Background: all cells in light gray
        ax.scatter(coords[:, 0], coords[:, 1], s=1, c="lightgray", alpha=0.3)

        overlap = np.intersect1d(idx_src, idx_tgt)
        src_only = np.setdiff1d(idx_src, overlap)
        tgt_only = np.setdiff1d(idx_tgt, overlap)

        # Draw order is target, source, overlap (later groups on top).
        groups = (
            (tgt_only, "tab:orange", "target"),
            (src_only, "tab:blue", "source"),
            (overlap, "tab:green", "overlap"),
        )
        for members, group_color, group_label in groups:
            if len(members) > 0:
                ax.scatter(
                    coords[members, 0],
                    coords[members, 1],
                    s=point_size,
                    c=group_color,
                    alpha=0.6,
                    label=group_label,
                )

        if pair_title:
            ax.set_title(f"Pair {pair_idx}", fontsize=18, fontfamily="sans serif")

        if not frameon:
            ax.axis("off")
        else:
            ax.set_xticks([])
            ax.set_yticks([])

        # Label the arrows after the plotted basis (previously always UMAP).
        add_umap_axis(ax, basis=basis)

    # Remove grid cells that received no pair.
    for unused in axes[n_show:]:
        fig.delaxes(unused)

    if title is not None:
        fig.suptitle(
            f"Window pairs ({n_show} of {n_total} shown)",
            fontsize=18,
            fontweight="bold",
        )

    # The legend goes on the last drawn panel. It is referenced explicitly
    # rather than through a loop variable leaking out of the for-loop.
    last_ax = axes[n_show - 1]
    handles, labels = last_ax.get_legend_handles_labels()
    last_ax.legend(
        handles,
        labels,
        fontsize=14,
        markerscale=2,
        loc="lower right",
        frameon=False,
    )

    plt.tight_layout()
    _finish(fig, show=show, save=save, save_path=save_path)
    return fig if not show else None
def window_transport(
    adata: AnnData,
    pair_index: int,
    basis: str = "umap",
    velocity_basis: str = "X_pca",
    reg: float = 0.05,
    lambda_time: float = 1.0,
    lambda_knn: float = 1.0,
    show_top_k: Optional[int] = None,
    padding: float = 0.3,
    background_alpha: float = 0.4,
    background_size: int = 15,
    background_color: str = "lightgray",
    point_size: int = 60,
    arrow_alpha: float = 0.4,
    arrow_width: float = 0.003,
    arrow_color: str = "black",
    colorby_weight: bool = True,
    cluster_key: Optional[str] = None,
    frameon: bool = False,
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (4, 4),
) -> plt.Figure:
    """
    Zoom into one window pair and overlay its OT transport plan.

    The transport plan between the pair's source and target cells is
    recomputed here (Sinkhorn, with an exact-EMD fallback), row-normalized,
    and drawn as arrows from source to target cells. Non-window cells inside
    the zoomed region are drawn as context.

    Parameters
    ----------
    adata
        Must have ``adata.uns['velot_windows']['pairs']``,
        ``adata.obsm['X_<basis>']``, ``adata.obsm[velocity_basis]`` and
        ``adata.obs['pseudotime']``. If ``adata.obsp['connectivities']``
        exists it is used for the neighbor penalty.
    pair_index
        Index of the window pair to visualize.
    basis
        Embedding used for drawing.
    velocity_basis
        Key in ``adata.obsm`` whose coordinates define the OT cost
        (squared Euclidean, scaled by its maximum).
    reg
        Sinkhorn regularization.
    lambda_time
        Cost added where the source pseudotime exceeds the target's.
    lambda_knn
        Cost added where source and target have zero connectivity.
    show_top_k
        Draw arrows only to the top-k targets per source cell. If None,
        arrows are drawn where the row-normalized weight exceeds
        ``1 / (5 * n_target)``.
    padding
        Fractional padding around the window's bounding box for the zoom.
    background_alpha
        Transparency of background (non-window) cells.
    background_size
        Point size of background cells.
    background_color
        Color for background cells when ``cluster_key`` is not used.
    point_size
        Size of source/target cells.
    arrow_alpha
        Base arrow transparency.
    arrow_width
        Base arrow width (multiplied by 300 to obtain the line width).
    arrow_color
        Arrow color when ``colorby_weight=False``.
    colorby_weight
        Color, fade and thicken arrows by transport weight.
    cluster_key
        If given and present in ``adata.obs``, background cells are colored
        per category (tab20 colormap) instead of a uniform color.
    frameon
        Show the axes frame.
    show
        Whether to display the figure.
    save
        Whether to write the figure to disk.
    save_path
        Destination path used when ``save=True``.
    figsize
        Figure size.

    Returns
    -------
    The Figure if ``show`` is False, otherwise None.

    Raises
    ------
    ValueError
        If windows are missing, ``pair_index`` is out of range, or a
        required ``obsm`` key is missing.
    KeyError
        If ``adata.obs`` has no ``"pseudotime"`` column.
    RuntimeError
        If both the Sinkhorn and the EMD solver fail.

    Examples
    --------
    ::

        # With cluster context
        velot.pl.window_transport(adata, 5, cluster_key="clusters")

        # Clean, gray background
        velot.pl.window_transport(adata, 5)

        # Top-3 connections only
        velot.pl.window_transport(adata, 5, show_top_k=3)
    """
    import ot as pot

    if "velot_windows" not in adata.uns:
        raise ValueError("No windows found. Run velot.tl.build_windows() first.")

    pairs = adata.uns["velot_windows"]["pairs"]
    n_total = len(pairs)
    if pair_index < 0 or pair_index >= n_total:
        raise ValueError(
            f"pair_index {pair_index} out of range (0 to {n_total - 1})."
        )
    idx_src, idx_tgt = pairs[pair_index]

    # Visualization coordinates
    coords_key = f"X_{basis}"
    if coords_key not in adata.obsm:
        raise ValueError(f"'{coords_key}' not found in adata.obsm.")
    coords = adata.obsm[coords_key]

    # OT computation coordinates
    if velocity_basis not in adata.obsm:
        raise ValueError(f"'{velocity_basis}' not found in adata.obsm.")
    X = adata.obsm[velocity_basis]

    # Same exception type the bare column lookup raises, but with a message
    # that says what is missing.
    if "pseudotime" not in adata.obs:
        raise KeyError("'pseudotime' not found in adata.obs.")
    pseudotime = adata.obs["pseudotime"].values

    # ------------------------------------------------------------------
    # Recompute transport plan
    # ------------------------------------------------------------------
    X_src = X[idx_src]
    X_tgt = X[idx_tgt]
    n_src = X_src.shape[0]
    n_tgt = X_tgt.shape[0]

    # Uniform marginals over source and target cells.
    a = np.ones(n_src, dtype=np.float64) / n_src
    b = np.ones(n_tgt, dtype=np.float64) / n_tgt

    C = pot.dist(X_src, X_tgt, metric="sqeuclidean")
    C_max = C.max()
    if C_max > 0:
        C = C / C_max

    # Penalize transport that goes backward in pseudotime.
    t_src = pseudotime[idx_src]
    t_tgt = pseudotime[idx_tgt]
    time_diff = t_src[:, None] - t_tgt[None, :]
    C[time_diff > 0] += lambda_time

    # Penalize source/target pairs with zero connectivity.
    if "connectivities" in adata.obsp:
        knn_adj = adata.obsp["connectivities"]
        local_adj = knn_adj[idx_src][:, idx_tgt]
        if hasattr(local_adj, "toarray"):
            local_adj = local_adj.toarray()
        C[local_adj == 0] += lambda_knn

    # Best effort: Sinkhorn first, exact EMD as fallback.
    try:
        P = pot.sinkhorn(a, b, C, reg=reg, numItermax=300, stopThr=1e-6)
    except Exception:
        try:
            P = pot.emd(a, b, C)
        except Exception as err:
            # Keep the underlying solver error attached for debugging.
            raise RuntimeError("OT computation failed for this pair.") from err

    # Row-normalize so each source cell's outgoing weights sum to 1
    # (rows that sum to 0 are left as zeros).
    row_sums = P.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0
    P_norm = P / row_sums

    # ------------------------------------------------------------------
    # Compute zoom region from window cells
    # ------------------------------------------------------------------
    coords_src = coords[idx_src]
    coords_tgt = coords[idx_tgt]
    all_window_coords = np.vstack([coords_src, coords_tgt])

    x_min, y_min = all_window_coords.min(axis=0)
    x_max, y_max = all_window_coords.max(axis=0)
    x_range = max(x_max - x_min, 1e-6)
    y_range = max(y_max - y_min, 1e-6)

    xlim = (x_min - padding * x_range, x_max + padding * x_range)
    ylim = (y_min - padding * y_range, y_max + padding * y_range)

    # ------------------------------------------------------------------
    # Identify which cells are in the zoomed region (for background)
    # ------------------------------------------------------------------
    in_view = (
        (coords[:, 0] >= xlim[0])
        & (coords[:, 0] <= xlim[1])
        & (coords[:, 1] >= ylim[0])
        & (coords[:, 1] <= ylim[1])
    )

    # Cells that are in view but NOT in the window pair
    all_window_idx = np.union1d(idx_src, idx_tgt)
    is_window_cell = np.zeros(adata.n_obs, dtype=bool)
    is_window_cell[all_window_idx] = True
    bg_mask = in_view & ~is_window_cell

    # ------------------------------------------------------------------
    # Plot
    # ------------------------------------------------------------------
    fig, ax = plt.subplots(figsize=figsize)

    # Layer 1: Background cells in the zoomed region
    bg_indices = np.where(bg_mask)[0]

    if cluster_key is not None and cluster_key in adata.obs:
        categories = adata.obs[cluster_key].astype("category")
        cat_names = categories.cat.categories.tolist()
        codes = categories.cat.codes.values
        cmap = plt.cm.tab20
        colors_array = cmap(np.linspace(0, 1, len(cat_names)))

        for ci, cat_name in enumerate(cat_names):
            cat_bg = bg_indices[codes[bg_indices] == ci]
            if len(cat_bg) > 0:
                ax.scatter(
                    coords[cat_bg, 0],
                    coords[cat_bg, 1],
                    s=background_size,
                    c=[colors_array[ci]],
                    alpha=background_alpha,
                    label=cat_name,
                    zorder=1,
                )
    else:
        if len(bg_indices) > 0:
            ax.scatter(
                coords[bg_indices, 0],
                coords[bg_indices, 1],
                s=background_size,
                c=background_color,
                alpha=background_alpha,
                zorder=1,
            )

    # Layer 2: Transport arrows (behind the highlighted cells)
    w_max = P_norm.max()
    if not w_max > 0:
        w_max = 1.0

    for i in range(n_src):
        src_coord = coords_src[i]

        if show_top_k is not None:
            top_indices = np.argsort(P_norm[i])[::-1][:show_top_k]
        else:
            threshold = 1.0 / (n_tgt * 5)
            top_indices = np.where(P_norm[i] > threshold)[0]

        for j in top_indices:
            weight = P_norm[i, j]
            if weight < 1e-10:
                continue

            tgt_coord = coords_tgt[j]

            if colorby_weight:
                # Intensity saturates at half of the maximum weight.
                intensity = min(1.0, weight / (w_max * 0.5))
                color = plt.cm.Reds(0.3 + 0.7 * intensity)
                alpha = max(0.1, min(0.8, arrow_alpha + 0.4 * intensity))
                width = arrow_width * (0.5 + 1.5 * intensity)
            else:
                color = arrow_color
                alpha = arrow_alpha
                width = arrow_width

            ax.annotate(
                "",
                xy=(tgt_coord[0], tgt_coord[1]),
                xytext=(src_coord[0], src_coord[1]),
                arrowprops=dict(
                    arrowstyle="-|>",
                    color=color,
                    lw=width * 300,
                    alpha=alpha,
                    shrinkA=2,
                    shrinkB=2,
                    mutation_scale=12,
                ),
                zorder=2,
            )

    # Layer 3: Source and target cells on top
    overlap = np.intersect1d(idx_src, idx_tgt)
    src_only = np.setdiff1d(idx_src, overlap)
    tgt_only = np.setdiff1d(idx_tgt, overlap)

    highlighted = (
        (tgt_only, "tab:orange", "target"),
        (src_only, "tab:blue", "source"),
        (overlap, "tab:green", "overlap"),
    )
    for members, group_color, group_label in highlighted:
        if len(members) > 0:
            ax.scatter(
                coords[members, 0],
                coords[members, 1],
                s=point_size,
                c=group_color,
                alpha=0.85,
                edgecolors="black",
                linewidths=0.4,
                label=group_label,
                zorder=4,
            )

    # ------------------------------------------------------------------
    # Zoom and formatting
    # ------------------------------------------------------------------
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    # Legend without duplicate labels
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(
        by_label.values(),
        by_label.keys(),
        fontsize=20,
        loc="lower right",
        frameon=False,
    )

    ax.set_title(
        f"OT transport for pair {pair_index}",
        fontsize=26,
        fontfamily="sans serif",
    )

    if not frameon:
        ax.axis("off")

    plt.tight_layout()
    _finish(fig, show=show, save=save, save_path=save_path)
    return fig if not show else None
from typing import Optional, Sequence from anndata import AnnData import scvelo as scv
def velocity_stream(
    adata: AnnData,
    color: str = "clusters",
    basis: str = "umap",
    velocity_key: str = "velot_velocity",
    title: Optional[str] = "VelOT velocity",
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (4, 4),
    ax: Optional[plt.Axes] = None,
    **kwargs,
) -> plt.Figure:
    """
    Stream plot of the velocity field using scVelo.

    Thin wrapper around ``scvelo.pl.velocity_embedding_stream`` that adds a
    title and labelled axis arrows. scVelo is required; there is no
    fallback.

    Parameters
    ----------
    adata
        Annotated data matrix with the embedding and velocities available
        to scVelo.
    color
        Key passed to scVelo as ``color``.
    basis
        Embedding basis; also used to label the axis arrows.
    velocity_key
        Passed to scVelo as ``vkey``.
        NOTE(review): scVelo is expected to resolve the embedded
        velocities from this key together with ``basis`` — confirm against
        the scVelo documentation.
    title
        Axes title. If None, no title is set.
    show
        Whether to display the figure. Only used when ``ax`` is None.
    save
        Whether to write the figure to disk. Only used when ``ax`` is None.
    save_path
        Destination path used when ``save=True``.
    figsize
        Figure size, used when a new figure is created.
    ax
        Existing axes to draw on. If None, a new figure is created and
        finalized by this function.
    **kwargs
        Passed to ``scvelo.pl.velocity_embedding_stream``.

    Returns
    -------
    None when ``show`` is True. Otherwise the new Figure, or the supplied
    ``ax`` when one was passed in.

    Raises
    ------
    ImportError
        If scVelo is not installed.
    """
    # Same guard as velocity_comparison: fail with an actionable message
    # (and before creating a figure) instead of a NameError on `scv`.
    if not _HAS_SCVELO:
        raise ImportError(
            "velocity_stream requires scVelo for stream plots. "
            "Install with: pip install scvelo"
        )

    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=figsize)

    scv.pl.velocity_embedding_stream(
        adata,
        basis=basis,
        vkey=velocity_key,
        color=color,
        title="",
        ax=ax,
        show=False,
        **kwargs,
    )

    if title is not None:
        ax.set_title(title, fontsize=18, fontfamily="sans serif")
    # Label the arrows after the plotted basis (previously always UMAP).
    add_umap_axis(ax, basis=basis)

    if owns_figure:
        _finish(fig, show=show, save=save, save_path=save_path)
        return fig if not show else None
    return ax if not show else None
def velocity_quiver(
    adata: AnnData,
    color: str = "clusters",
    basis: str = "umap",
    velocity_key: str = "velocity_umap",
    title: Optional[str] = None,
    spot_size=50,
    arrow_scale: float = 1.0,
    subsample: Optional[int] = None,
    arrow_color: str = "black",
    arrow_alpha: float = 0.7,
    normalize_arrows: bool = False,
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (5, 5),
    ax: Optional[plt.Axes] = None,
    **scatter_kwargs,
) -> plt.Figure:
    """
    Quiver plot of velocity arrows overlaid on a scVelo scatter.

    The base plot is drawn with ``scv.pl.scatter``; scVelo is required and
    there is no fallback.

    Parameters
    ----------
    adata
        Must have ``adata.obsm[f"X_{basis}"]`` and
        ``adata.obsm[velocity_key]``.
    color
        Key passed as ``color`` to the base scatter plot.
    basis
        Embedding to use: ``"umap"``, ``"pca"``, etc.
    velocity_key
        Key in ``adata.obsm`` with velocity vectors **in the same
        coordinate system as the embedding**. Only the first two columns
        are plotted.
    title
        Plot title. The special value ``"auto"`` produces
        ``"VelOT velocity (quiver — <BASIS>)"``.
    spot_size
        Passed as ``size`` to the base scatter plot.
    arrow_scale
        Multiplier on the automatic arrow length.
    subsample
        Number of cells to draw arrows for (fixed seed 42). ``None`` or a
        value >= ``adata.n_obs`` draws all cells.
    arrow_color
        Color of the quiver arrows.
    arrow_alpha
        Transparency of arrows (0–1).
    normalize_arrows
        If True, arrows are normalized to unit length before scaling, so
        they show direction only.
    show
        Whether to display the figure. Only used when ``ax`` is None.
    save
        Whether to write the figure to disk. Only used when ``ax`` is None.
    save_path
        Destination path used when ``save=True``.
    figsize
        Figure size in inches.
    ax
        Existing axes to draw on. If None, a new figure is created and
        finalized by this function.
    **scatter_kwargs
        Extra keyword arguments passed to ``scv.pl.scatter``.

    Returns
    -------
    None when ``show`` is True. Otherwise the new Figure, or the supplied
    ``ax`` when one was passed in.

    Raises
    ------
    ImportError
        If scVelo is not installed.
    ValueError
        If the embedding or velocity key is missing, or the velocity has
        fewer than two columns.

    Examples
    --------
    Basic quiver on UMAP::

        velot.pl.velocity_quiver(adata, velocity_key="velocity_umap")

    Quiver on PCA (first 2 components), subsampled::

        velot.pl.velocity_quiver(
            adata,
            basis="pca",
            velocity_key="velot_velocity",
            subsample=500,
        )

    Normalized arrows colored red::

        velot.pl.velocity_quiver(
            adata,
            normalize_arrows=True,
            arrow_color="red",
            arrow_scale=0.5,
            subsample=800,
        )
    """
    # Same guard as velocity_comparison: the base scatter needs scVelo.
    if not _HAS_SCVELO:
        raise ImportError(
            "velocity_quiver requires scVelo for the base scatter plot. "
            "Install with: pip install scvelo"
        )

    coords_key = f"X_{basis}"
    if coords_key not in adata.obsm:
        raise ValueError(f"'{coords_key}' not found in adata.obsm.")
    if velocity_key not in adata.obsm:
        raise ValueError(f"'{velocity_key}' not found in adata.obsm.")

    coords = adata.obsm[coords_key][:, :2]  # always first 2 dims for plot
    vel = adata.obsm[velocity_key]

    # Validate up front; otherwise the failure surfaces as an IndexError
    # after the figure has already been created.
    if vel.ndim != 2 or vel.shape[1] < 2:
        raise ValueError(
            f"'{velocity_key}' must have at least 2 columns for a 2D quiver plot."
        )

    # Only use first 2 dimensions of velocity for 2D plot
    vel_2d = vel[:, :2] if vel.shape[1] > 2 else vel

    plot_title = title
    if title == "auto":
        plot_title = f"VelOT velocity (quiver — {basis.upper()})"

    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=figsize)

    scv.pl.scatter(
        adata,
        basis=basis,
        color=color,
        size=spot_size,
        title="",
        frameon=False,
        legend_loc="on data",
        figsize=figsize,
        show=False,
        ax=ax,
        **scatter_kwargs,
    )
    ax.set_title(plot_title, fontsize=18, fontfamily="sans serif")

    # ------------------------------------------------------------------
    # Subsample for readability
    # ------------------------------------------------------------------
    if subsample is not None and subsample < adata.n_obs:
        # Fixed seed so repeated calls draw the same cells.
        rng = np.random.RandomState(42)
        idx = rng.choice(adata.n_obs, subsample, replace=False)
    else:
        idx = np.arange(adata.n_obs)

    # ------------------------------------------------------------------
    # Velocity arrows
    # ------------------------------------------------------------------
    v = vel_2d[idx].copy()

    if normalize_arrows:
        norms = np.linalg.norm(v, axis=1, keepdims=True)
        v = v / (norms + 1e-10)

    # Auto-scale: make arrows visible relative to the embedding range.
    # The user's arrow_scale multiplies this base scale.
    coord_range = np.ptp(coords, axis=0).mean()  # mean span of embedding
    vel_mag = np.linalg.norm(v, axis=1).mean()

    if vel_mag > 1e-10:
        # Base: arrows span ~3% of the plot range on average
        auto_scale = (coord_range * 0.03) / vel_mag
    else:
        auto_scale = 1.0

    scale_factor = auto_scale * arrow_scale

    ax.quiver(
        coords[idx, 0],
        coords[idx, 1],
        v[:, 0] * scale_factor,
        v[:, 1] * scale_factor,
        angles="xy",
        scale_units="xy",
        scale=1,
        width=0.002 * coord_range / 20,  # scale width to plot range
        headwidth=4,
        headlength=5,
        color=arrow_color,
        alpha=arrow_alpha,
        zorder=10,
    )
    add_umap_axis(ax, basis=basis)
    plt.tight_layout()

    if owns_figure:
        _finish(fig, show=show, save=save, save_path=save_path)
        return fig if not show else None
    return ax if not show else None
def confidence(
    adata: AnnData,
    basis: str = "umap",
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (7, 6),
) -> plt.Figure:
    """
    Plot the per-cell values in ``adata.obs['velot_confidence']``.

    The embedding is colored by that column using the ``YlOrRd`` colormap.
    NOTE(review): the column is described upstream as the number of windows
    contributing to each cell's velocity — confirm against
    ``velot.tl.compute_ot_velocity``.

    Parameters
    ----------
    adata
        Must have ``adata.obs['velot_confidence']``.
    basis
        Embedding to plot.
    show
        Whether to display the figure.
    save
        Whether to write the figure to disk.
    save_path
        Destination path used when ``save=True``.
    figsize
        Figure size.

    Returns
    -------
    The Figure if ``show`` is False, otherwise None.

    Raises
    ------
    ValueError
        If the confidence column is missing.
    """
    has_confidence = "velot_confidence" in adata.obs
    if not has_confidence:
        raise ValueError("No confidence data. Run velot.tl.compute_ot_velocity() first.")

    figure, axis = plt.subplots(figsize=figsize)

    embedding_options = dict(
        basis=basis,
        color="velot_confidence",
        ax=axis,
        show=False,
        title="VelOT confidence",
        color_map="YlOrRd",
        frameon=False,
    )
    sc.pl.embedding(adata, **embedding_options)

    plt.tight_layout()
    _finish(figure, show=show, save=save, save_path=save_path)

    if show:
        return None
    return figure
def training_curves(
    adata: AnnData,
    vertical: bool=False,
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (10, 4),
) -> plt.Figure:
    """
    Plot the regression and smoothness loss curves on two log-scaled panels.

    Parameters
    ----------
    adata
        Must have ``adata.uns['velot_smoothing']`` with the entries
        ``'losses_regression'`` and ``'losses_smoothness'``.
    vertical
        If True, stack the panels (2x1, shared x-axis); otherwise place
        them side by side (1x2, shared y-axis).
    show
        Whether to display the figure.
    save
        Whether to write the figure to disk.
    save_path
        Destination path used when ``save=True``.
    figsize
        Figure size.

    Returns
    -------
    The Figure if ``show`` is False, otherwise None.

    Raises
    ------
    ValueError
        If no smoothing data is stored.
    """
    if "velot_smoothing" not in adata.uns:
        raise ValueError("No smoothing data. Run velot.tl.smooth_velocity() first.")

    history = adata.uns["velot_smoothing"]

    if vertical:
        grid = dict(nrows=2, ncols=1, sharex=True)
    else:
        grid = dict(nrows=1, ncols=2, sharey=True)
    fig, axes = plt.subplots(figsize=figsize, **grid)

    # (history key, line color, panel title) for each panel, in order.
    panel_specs = (
        ("losses_regression", "tab:blue", "Regression loss"),
        ("losses_smoothness", "tab:orange", "Smoothness loss"),
    )
    for panel, (loss_key, line_color, heading) in zip(axes, panel_specs):
        panel.plot(history[loss_key], color=line_color, alpha=0.8)
        panel.set_title(heading, fontsize=16)
        panel.set_xlabel("Epoch")
        panel.set_ylabel("Loss")
        panel.set_yscale("log")

    plt.tight_layout()
    _finish(fig, show=show, save=save, save_path=save_path)

    if show:
        return None
    return fig
def training_curves_single(
    adata: AnnData,
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (7,7),
) -> plt.Figure:
    """
    Plot the regression, smoothness and curl loss curves on one log axis.

    Parameters
    ----------
    adata
        Must have ``adata.uns['velot_smoothing']`` with the entries
        ``'losses_regression'``, ``'losses_smoothness'`` and
        ``'losses_curl'``.
    show
        Whether to display the figure.
    save
        Whether to write the figure to disk.
    save_path
        Destination path used when ``save=True``.
    figsize
        Figure size.

    Returns
    -------
    The Figure if ``show`` is False, otherwise None.

    Raises
    ------
    ValueError
        If no smoothing data is stored.
    """
    if "velot_smoothing" not in adata.uns:
        raise ValueError("No smoothing data. Run velot.tl.smooth_velocity() first.")

    history = adata.uns["velot_smoothing"]
    fig, panel = plt.subplots(figsize=figsize)

    # History key -> legend label, drawn in this order.
    curve_labels = {
        "losses_regression": "Regression",
        "losses_smoothness": "Smoothness",
        "losses_curl": "Curl",
    }
    for loss_key, legend_label in curve_labels.items():
        panel.plot(history[loss_key], alpha=0.8, label=legend_label)

    panel.set_xlabel("Epoch")
    panel.set_ylabel("Loss")
    panel.set_yscale("log")
    panel.legend(loc="lower right")

    _finish(fig, show=show, save=save, save_path=save_path)

    if show:
        return None
    return fig
def spatial_clusters(
    adata: AnnData,
    basis: str = "umap",
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (7, 6),
) -> plt.Figure:
    """
    Plot the embedding colored by ``adata.obs['velot_spatial_cluster']``.

    Cluster labels are drawn on the data (``legend_loc="on data"``).

    Parameters
    ----------
    adata
        Must have ``adata.obs['velot_spatial_cluster']`` (the error message
        points to ``velot.tl.build_windows()``).
    basis
        Embedding to plot.
    show
        Whether to display the figure.
    save
        Whether to write the figure to disk.
    save_path
        Destination path used when ``save=True``.
    figsize
        Figure size.

    Returns
    -------
    The Figure if ``show`` is False, otherwise None.

    Raises
    ------
    ValueError
        If the spatial-cluster column is missing.
    """
    cluster_column = "velot_spatial_cluster"
    if cluster_column not in adata.obs:
        raise ValueError(
            "No spatial clusters. Run velot.tl.build_windows() first."
        )

    figure, axis = plt.subplots(figsize=figsize)

    embedding_options = dict(
        basis=basis,
        color=cluster_column,
        ax=axis,
        show=False,
        title="VelOT spatial clusters",
        frameon=False,
        legend_loc="on data",
    )
    sc.pl.embedding(adata, **embedding_options)

    plt.tight_layout()
    _finish(figure, show=show, save=save, save_path=save_path)

    if show:
        return None
    return figure
def velocity_comparison(
    adata: AnnData,
    velocity_keys: Sequence[str],
    titles: Optional[Sequence[str]] = None,
    color: str = "clusters",
    basis: str = "umap",
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize_per_panel: tuple = (6, 5),
) -> plt.Figure:
    """
    Side-by-side stream plots comparing multiple velocity fields.

    Each field is temporarily copied into ``adata.obsm[f"velocity_{basis}"]``
    so that scVelo's default lookup plots it; the original content of that
    slot is restored (or the slot removed) afterwards, even if plotting
    fails.

    Parameters
    ----------
    adata
        Annotated data matrix.
    velocity_keys
        Keys in ``adata.obsm`` with velocity fields to compare. A key that
        is missing yields a placeholder panel.
    titles
        Optional titles for each panel. Defaults to ``velocity_keys``.
    color
        Key passed to scVelo as ``color``.
    basis
        Embedding basis.
    show
        Whether to display the figure.
    save
        Whether to write the figure to disk.
    save_path
        Destination path used when ``save=True``.
    figsize_per_panel
        Size of one panel; the figure width is scaled by the panel count.

    Returns
    -------
    The Figure if ``show`` is False, otherwise None.

    Raises
    ------
    ImportError
        If scVelo is not installed.
    ValueError
        If ``velocity_keys`` is empty.

    Example
    -------
    ::

        # Compare raw vs smoothed velocity
        velot.pl.velocity_comparison(
            adata,
            velocity_keys=["velot_velocity_raw_umap", "velocity_umap"],
            titles=["Raw OT", "Smoothed"],
        )
    """
    if not _HAS_SCVELO:
        raise ImportError(
            "velocity_comparison requires scVelo for stream plots. "
            "Install with: pip install scvelo"
        )

    n = len(velocity_keys)
    if n == 0:
        raise ValueError("velocity_keys must contain at least one key.")
    titles = titles or velocity_keys

    figsize = (figsize_per_panel[0] * n, figsize_per_panel[1])
    fig, axes = plt.subplots(1, n, figsize=figsize)
    if n == 1:
        axes = [axes]

    # Slot scVelo reads with its default vkey ("velocity") for this basis.
    # NOTE(review): assumes scVelo resolves embedded velocities as
    # f"{vkey}_{basis}" — confirm against the scVelo documentation.
    slot = f"velocity_{basis}"
    had_original = slot in adata.obsm
    backup = adata.obsm[slot] if had_original else None

    # Resolve every requested field BEFORE the slot is overwritten, so that
    # requesting the slot's own key (e.g. "velocity_umap") still plots the
    # original field rather than whatever the previous panel put there.
    fields = {key: adata.obsm[key] for key in velocity_keys if key in adata.obsm}

    try:
        for panel, vkey, title in zip(axes, velocity_keys, titles):
            if vkey not in fields:
                panel.text(
                    0.5,
                    0.5,
                    f"'{vkey}' not found",
                    ha="center",
                    va="center",
                    transform=panel.transAxes,
                )
                continue

            adata.obsm[slot] = fields[vkey]
            scv.pl.velocity_embedding_stream(
                adata,
                basis=basis,
                color=color,
                title=title,
                ax=panel,
                show=False,
            )
    finally:
        # Leave adata as it was found: restore the original field, or drop
        # the temporary slot if there was none.
        if had_original:
            adata.obsm[slot] = backup
        elif slot in adata.obsm:
            del adata.obsm[slot]

    plt.tight_layout()
    _finish(fig, show=show, save=save, save_path=save_path)
    return fig if not show else None
# =====================================================================
# Internal helpers
# =====================================================================


def _finish(fig: plt.Figure, show: bool, save: bool, save_path: Optional[str]):
    """
    Handle the shared save/show/close logic for the plot functions.

    Parameters
    ----------
    fig
        Figure to finalize.
    show
        If True, call ``plt.show()``.
    save
        If True, save ``fig`` at 300 dpi with a tight bounding box.
    save_path
        Destination path. If None while ``save`` is True, the figure is
        written to ``"velot_figure.png"`` and a warning is printed.
        Missing parent directories are created.
    """
    if save:
        # Fallback if they say save=True but forget to give a path
        if save_path is None:
            save_path = "velot_figure.png"
            print(f"Warning: save=True but no save_path provided. Saving to {save_path}")

        path = Path(save_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=300, bbox_inches="tight")

    if show:
        plt.show()

    # Close THIS figure explicitly; a bare plt.close() closes whichever
    # figure is current, which need not be `fig`.
    plt.close(fig)
def gridsearch_results(
    df,
    metric: str = "iccoh_mean",
    param_x: Optional[str] = None,
    param_hue: Optional[str] = None,
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (12, 5),
) -> plt.Figure:
    """
    Visualize grid search results.

    Left panel: ``metric`` against ``param_x`` (one line per ``param_hue``
    value). Right panel: horizontal bars for the runs with the largest
    ``metric`` (at most 10). Only rows with ``status == "ok"`` are used.

    Parameters
    ----------
    df
        DataFrame of grid search runs; must contain a ``"status"`` column
        and the ``metric`` column. (Described upstream as the return value
        of ``velot.tl.gridsearch()``.)
    metric
        Column to plot as the y-axis.
    param_x
        Parameter for x-axis grouping. If None, the first detected
        parameter column is used.
    param_hue
        Parameter for color grouping. If None and more than one parameter
        column exists, one is chosen automatically (never the same column
        as ``param_x``).
    show
        Whether to display the figure.
    save
        Whether to write the figure to disk.
    save_path
        Destination path used when ``save=True``.
    figsize
        Figure size.

    Returns
    -------
    The Figure if ``show`` is False, otherwise None.

    Raises
    ------
    ValueError
        If no parameter columns are detected or no run succeeded.
    KeyError
        If ``metric`` (or ``"status"``) is not a column of ``df``.

    Example
    -------
    ::

        results = velot.tl.gridsearch(adata, param_grid, ...)
        velot.pl.gridsearch_results(results, metric="cbdir_mean")
    """
    # Auto-detect parameters (columns that aren't metrics)
    non_param_cols = {
        "iccoh_mean",
        "cbdir_mean",
        "elapsed_seconds",
        "status",
    }
    param_cols = [
        c
        for c in df.columns
        if c not in non_param_cols and not c.startswith(("iccoh_", "cbdir_"))
    ]

    if not param_cols:
        raise ValueError("No parameter columns detected in DataFrame.")

    # Fail before a figure exists so a bad metric name does not leak an
    # open pyplot figure (same exception type as the column lookup itself).
    if metric not in df.columns:
        raise KeyError(f"Metric column '{metric}' not found in DataFrame.")

    if param_x is None:
        param_x = param_cols[0]
    if param_hue is None and len(param_cols) > 1:
        # Default to the second parameter, unless that is already the
        # x-axis; coloring by the x parameter would give one-point lines.
        param_hue = param_cols[1] if param_cols[1] != param_x else param_cols[0]

    # Filter to successful runs
    df_ok = df[df["status"] == "ok"].copy()
    if len(df_ok) == 0:
        raise ValueError("No successful runs to plot.")

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # ------------------------------------------------------------------
    # Panel 1: metric vs param_x, colored by param_hue
    # ------------------------------------------------------------------
    ax = axes[0]
    if param_hue is not None:
        hue_values = sorted(df_ok[param_hue].unique())
        colors = plt.cm.tab10(np.linspace(0, 1, len(hue_values)))
        for hue_val, color in zip(hue_values, colors):
            subset = df_ok[df_ok[param_hue] == hue_val].sort_values(param_x)
            # x values are plotted as strings, i.e. as categories.
            ax.plot(
                subset[param_x].astype(str),
                subset[metric],
                "o-",
                label=f"{param_hue}={hue_val}",
                color=color,
                markersize=8,
                alpha=0.8,
            )
        ax.legend(fontsize=8)
    else:
        df_sorted = df_ok.sort_values(param_x)
        ax.plot(
            df_sorted[param_x].astype(str),
            df_sorted[metric],
            "o-",
            color="tab:blue",
            markersize=8,
        )

    ax.set_xlabel(param_x)
    ax.set_ylabel(metric)
    ax.set_title(f"{metric} vs {param_x}")
    ax.tick_params(axis="x", rotation=45)

    # ------------------------------------------------------------------
    # Panel 2: summary bar chart of top 10 runs
    # ------------------------------------------------------------------
    ax = axes[1]
    top = df_ok.nlargest(min(10, len(df_ok)), metric)

    # One multi-line "param=value" label per run.
    labels = [
        "\n".join(f"{p}={row[p]}" for p in param_cols)
        for _, row in top.iterrows()
    ]

    ax.barh(
        range(len(top)),
        top[metric].values,
        color="tab:green",
        alpha=0.7,
    )
    ax.set_yticks(range(len(top)))
    ax.set_yticklabels(labels, fontsize=7)
    ax.set_xlabel(metric)
    ax.set_title(f"Top {len(top)} configurations")
    ax.invert_yaxis()  # best run at the top

    fig.suptitle("VelOT Grid Search Results", fontsize=14, fontweight="bold")
    plt.tight_layout()
    _finish(fig, show=show, save=save, save_path=save_path)
    return fig if not show else None
def trajectories(
    adata: AnnData,
    trajectory_ids: Optional[Sequence[int]] = None,
    color: str = "clusters",
    basis: str = "umap",
    line_color: str = "black",
    line_alpha: float = 0.7,
    line_width: float = 1,
    line_style: str = "--",
    start_marker: str = "o",
    start_size: int = 50,
    start_color: str = "gray",
    end_marker: str = "o",
    end_size: int = 50,
    end_color: str = "black",
    arrow_frequency: int = 30,
    arrow_size: float = 10,
    show_start: bool = True,
    show_end: bool = True,
    show_arrows: bool = True,
    frameon: bool = False,
    title: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (5, 5),
    **scanpy_kwargs,
) -> plt.Axes:
    """
    Plot cell trajectories on top of an embedding scatter plot.

    Parameters
    ----------
    adata
        Must have ``adata.uns['velot_trajectories']``.
    trajectory_ids
        List of trajectory IDs to plot. If None, plots all.
        IDs are stored in each trajectory's metadata ``"id"`` field.
        Use this to select specific trajectories after inspecting
        the metadata.
    color
        Column used to color the background scatter.
    basis
        Embedding basis (``'umap'`` or ``'pca'``). For ``'pca'`` the first
        two components of each stored path are used.
    line_color
        Color of trajectory lines.
    line_alpha
        Line transparency.
    line_width
        Line width.
    line_style
        ``"-"`` solid, ``"--"`` dashed, ``":"`` dotted.
    start_marker, start_size, start_color
        Start point appearance.
    end_marker, end_size, end_color
        End point appearance.
    arrow_frequency
        Arrow every N steps. Values ``<= 0`` disable arrows.
    arrow_size
        Arrowhead size.
    show_start, show_end, show_arrows
        Toggle visual elements.
    frameon
        Show axes frame.
    title
        Plot title. ``None`` gives an empty title; the special value
        ``"auto"`` builds a title from the stored direction, start cluster,
        target cluster and the number of trajectories shown.
    ax
        Pre-existing axes. If None, creates new figure with scatter
        background. If provided, draws scatter background and
        trajectories on it — useful for subplot grids.
    show
        Display the plot. Ignored when ``ax`` is provided
        (caller controls display).
    save, save_path
        Forwarded to ``_finish``. Ignored when ``ax`` is provided.
    figsize
        Figure size, used only when ``ax`` is None.
    **scanpy_kwargs
        Passed to ``scv.pl.scatter`` for the background. ``legend_loc``
        defaults to ``"on data"`` but may be overridden here.

    Returns
    -------
    matplotlib Axes.

    Raises
    ------
    ValueError
        If no trajectories are stored, none exist for ``basis``, or none
        match ``trajectory_ids``.

    Examples
    --------
    Plot all trajectories::

        velot.pl.trajectories(adata)

    Plot specific trajectories by ID::

        velot.pl.trajectories(adata, trajectory_ids=[0, 3, 7])

    Grid of individual trajectories::

        n = 10
        cols = 5
        rows = (n + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=(5*cols, 5*rows))
        for i in range(n):
            velot.pl.trajectories(
                adata,
                trajectory_ids=[i],
                ax=axes.flat[i],
                show=False,
                title=f"Trajectory {i}",
            )
        for j in range(n, len(axes.flat)):
            axes.flat[j].axis("off")
        plt.tight_layout()
        plt.show()

    Inspect metadata to choose which to plot::

        meta = adata.uns["velot_trajectories"]["metadata"]
        for m in meta:
            print(f"ID {m['id']}: {m['origin_cluster']} → "
                  f"{m['terminal_cluster']} ({m['n_steps']} steps)")

        velot.pl.trajectories(adata, trajectory_ids=[2, 5])
    """
    if "velot_trajectories" not in adata.uns:
        raise ValueError(
            "No trajectories found. Run velot.tl.compute_trajectories() first."
        )

    traj_data = adata.uns["velot_trajectories"]
    all_metadata = traj_data["metadata"]

    # Length checks (not truthiness) so array-like containers work too.
    if basis == "umap" and len(traj_data.get("paths_umap", [])) > 0:
        all_paths = traj_data["paths_umap"]
    elif basis == "pca" and len(traj_data.get("paths_pca", [])) > 0:
        all_paths = [p[:, :2] for p in traj_data["paths_pca"]]
    else:
        raise ValueError(
            f"No trajectory paths available for basis='{basis}'. "
            f"Available: paths_umap={len(traj_data.get('paths_umap', []))}, "
            f"paths_pca={len(traj_data.get('paths_pca', []))}"
        )

    # Filter by trajectory_ids if specified
    if trajectory_ids is not None:
        id_set = set(trajectory_ids)
        selected = [
            (path, meta)
            for path, meta in zip(all_paths, all_metadata)
            if meta["id"] in id_set
        ]
        if not selected:
            available_ids = [m["id"] for m in all_metadata]
            raise ValueError(
                f"No trajectories found with IDs {trajectory_ids}. "
                f"Available IDs: {available_ids}"
            )
        paths = [s[0] for s in selected]
        metadata = [s[1] for s in selected]
    else:
        paths = all_paths
        metadata = all_metadata

    # ------------------------------------------------------------------
    # Background scatter — always drawn, whether ax is provided or not
    # ------------------------------------------------------------------
    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=figsize)

    # Caller-supplied kwargs win over our default legend placement; passing
    # legend_loc twice would raise a TypeError that the fallback below would
    # silently swallow.
    scatter_kwargs = {"legend_loc": "on data", **scanpy_kwargs}

    # Draw scatter on the axes (provided or newly created)
    try:
        scv.pl.scatter(
            adata,
            basis=basis,
            color=color,
            ax=ax,
            show=False,
            frameon=frameon,
            **scatter_kwargs,
        )
    except Exception:
        # Deliberate best-effort: scvelo is an optional import, so `scv` may
        # be undefined, or the call itself may fail. Either way draw a plain
        # matplotlib scatter instead.
        print("Trajectories on manual scatter...")
        # Fallback: manual scatter
        coords_key = f"X_{basis}"
        if coords_key in adata.obsm:
            coords = adata.obsm[coords_key][:, :2]
            if color in adata.obs:
                cats = adata.obs[color].astype("category")
                ax.scatter(
                    coords[:, 0],
                    coords[:, 1],
                    s=10,
                    c=cats.cat.codes.values,
                    cmap="tab20",
                    alpha=0.5,
                    zorder=1,
                )
            else:
                ax.scatter(
                    coords[:, 0],
                    coords[:, 1],
                    s=10,
                    c="lightgray",
                    alpha=0.5,
                    zorder=1,
                )

    fig = ax.get_figure()

    # ------------------------------------------------------------------
    # Overlay trajectories
    # ------------------------------------------------------------------
    draw_arrows = show_arrows and arrow_frequency > 0  # range() rejects step 0

    for path, meta in zip(paths, metadata):
        if len(path) < 2:
            continue

        # Backward trajectories are reversed so that the drawn start/end
        # markers and arrows follow display order.
        is_backward = meta["direction"] == "backward"
        display_path = path[::-1] if is_backward else path

        # Trajectory line
        ax.plot(
            display_path[:, 0],
            display_path[:, 1],
            color=line_color,
            alpha=line_alpha,
            linewidth=line_width,
            linestyle=line_style,
            zorder=10,
        )

        # Start point
        if show_start:
            ax.scatter(
                display_path[0, 0],
                display_path[0, 1],
                s=start_size,
                c=start_color,
                marker=start_marker,
                edgecolors="black",
                linewidths=0.5,
                zorder=12,
            )

        # End point
        if show_end:
            ax.scatter(
                display_path[-1, 0],
                display_path[-1, 1],
                s=end_size,
                c=end_color,
                marker=end_marker,
                edgecolors="white",
                linewidths=0.5,
                zorder=12,
            )

        # Direction arrows
        if draw_arrows:
            for step in range(
                arrow_frequency, len(display_path) - 1, arrow_frequency
            ):
                x0, y0 = display_path[step, 0], display_path[step, 1]
                dx = display_path[step + 1, 0] - x0
                dy = display_path[step + 1, 1] - y0
                mag = np.sqrt(dx**2 + dy**2)
                # Skip zero-length steps: they have no direction to show.
                if mag > 1e-10:
                    ax.annotate(
                        "",
                        xy=(x0 + dx * 0.5, y0 + dy * 0.5),
                        xytext=(x0, y0),
                        arrowprops=dict(
                            arrowstyle="-|>",
                            color=line_color,
                            lw=line_width * 0.6,
                            alpha=line_alpha,
                            mutation_scale=arrow_size,
                        ),
                        zorder=11,
                    )

    # ------------------------------------------------------------------
    # Title
    # ------------------------------------------------------------------
    if title is None:
        title = ""
    if title == "auto":
        direction = traj_data.get("direction", "forward")
        start = traj_data.get("start_cluster", "?")
        target = traj_data.get("target_cluster", None)

        parts = []
        if direction == "backward":
            parts.append(f"Backward from '{start}'")
        else:
            parts.append(f"Forward from '{start}'")
        if target:
            parts.append(f"→ '{target}'")
        parts.append(f"(n={len(paths)})")
        title = " ".join(parts)

    ax.set_title(title, fontsize=12)
    add_umap_axis(ax, basis=basis)

    # ------------------------------------------------------------------
    # Finish — only when we own the figure
    # ------------------------------------------------------------------
    if owns_figure:
        fig.tight_layout()
        _finish(fig, show=show, save=save, save_path=save_path)

    return ax
def fate_summary(
    adata: AnnData,
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (8, 4),
) -> plt.Figure:
    """
    Bar chart summarizing trajectory fate or origin.

    For forward trajectories: shows terminal cluster distribution.
    For backward trajectories: shows origin cluster distribution.
    One panel is drawn per distinct ``"direction"`` value found in the
    trajectory metadata.

    Parameters
    ----------
    adata
        Must have ``adata.uns['velot_trajectories']``.
    show
        Whether to display. Forwarded to ``_finish``.
    save, save_path
        Forwarded to ``_finish``.
    figsize
        Size for a two-panel figure; the width is scaled by
        ``n_panels / 2``.

    Returns
    -------
    matplotlib Figure when ``show`` is False, otherwise None.

    Raises
    ------
    ValueError
        If no trajectories are stored or the metadata list is empty.
    """
    if "velot_trajectories" not in adata.uns:
        raise ValueError(
            "No trajectories found. Run velot.tl.compute_trajectories() first."
        )

    traj_data = adata.uns["velot_trajectories"]
    metadata = traj_data["metadata"]
    start_cluster = traj_data.get("start_cluster", "unknown")

    # Separate by direction if "both"
    groups = {}
    for m in metadata:
        groups.setdefault(m["direction"], []).append(m)

    n_panels = len(groups)
    if n_panels == 0:
        # plt.subplots(1, 0) would fail with an unrelated-looking message.
        raise ValueError("Trajectory metadata is empty; nothing to summarize.")

    fig, axes = plt.subplots(
        1, n_panels, figsize=(figsize[0] * n_panels / 2, figsize[1])
    )
    if n_panels == 1:
        # plt.subplots returns a bare Axes for a 1x1 grid
        axes = [axes]

    for ax, (dir_name, dir_meta) in zip(axes, groups.items()):
        if dir_name == "forward":
            values = [m["terminal_cluster"] for m in dir_meta]
            subtitle = f"Forward from '{start_cluster}': terminal fates"
        else:
            values = [m["origin_cluster"] for m in dir_meta]
            subtitle = f"Backward from '{start_cluster}': traced origins"

        unique, counts = np.unique(values, return_counts=True)
        fractions = counts / counts.sum()

        # Most frequent cluster first
        order = np.argsort(-fractions)
        unique = unique[order]
        fractions = fractions[order]
        counts = counts[order]

        bars = ax.bar(
            range(len(unique)),
            fractions,
            color=plt.cm.tab20(np.linspace(0, 1, len(unique))),
            edgecolor="black",
            linewidth=0.5,
        )
        ax.set_xticks(range(len(unique)))
        ax.set_xticklabels(unique, rotation=45, ha="right", fontsize=9)
        ax.set_ylabel("Fraction of trajectories")
        ax.set_title(subtitle, fontsize=11)

        # Annotate each bar with "count (percent)"
        for bar, count, frac in zip(bars, counts, fractions):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 0.01,
                f"{count} ({frac:.0%})",
                ha="center",
                va="bottom",
                fontsize=8,
            )

        ax.set_ylim(0, min(1.0, fractions.max() * 1.3))

    fig.suptitle("VelOT trajectory fate analysis", fontsize=13, fontweight="bold")
    fig.tight_layout()

    _finish(fig, show=show, save=save, save_path=save_path)
    return fig if not show else None
def flow_simulation(
    adata: AnnData,
    color: str = "clusters",
    basis: str = "umap",
    line_alpha: float = 0.3,
    line_width: float = 0.8,
    colorby: str = "pseudotime",
    show_cells: bool = True,
    cell_alpha: float = 0.2,
    cell_size: int = 8,
    show_start: bool = True,
    show_end: bool = True,
    frameon: bool = False,
    title: Optional[str] = None,
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (9, 8),
) -> plt.Figure:
    """
    Visualize the flow simulation — particles flowing through the
    velocity field like water through a river.

    Parameters
    ----------
    adata
        Must have ``adata.uns['velot_flow']`` from ``velot.tl.simulate_flow()``.
    color
        Column for background cell coloring. Also defines the cluster
        palette used by ``colorby="cluster"`` and ``colorby="fate"``.
    basis
        Embedding basis used for the background cells
        (``adata.obsm[f"X_{basis}"]``). Note that particle paths are always
        read from ``flow["particles_umap"]``.
    line_alpha
        Trajectory line transparency.
    line_width
        Trajectory line width.
    colorby
        How to color trajectories:
        - ``"pseudotime"``: color by pseudotime along trajectory
          (blue=early, red=late)
        - ``"cluster"``: color by starting cluster
        - ``"fate"``: color by terminal cluster

        Any other value draws all trajectories in a single color. For the
        cluster modes, clusters not found in ``adata.obs[color]`` are gray.
    show_cells
        Show background cells.
    cell_alpha
        Background cell transparency.
    cell_size
        Background cell size.
    show_start
        Mark starting positions with stars.
    show_end
        Mark ending positions with dots.
    frameon
        Show axes frame.
    title
        Plot title. If None, a summary title is built from the flow
        metadata.
    show
        Display the plot. Forwarded to ``_finish``.
    save, save_path
        Forwarded to ``_finish``.
    figsize
        Figure size.

    Returns
    -------
    matplotlib Figure when ``show`` is False, otherwise None.

    Raises
    ------
    ValueError
        If no flow data is stored or it contains no paths.
    """
    if "velot_flow" not in adata.uns:
        raise ValueError(
            "No flow data. Run velot.tl.simulate_flow() first."
        )

    flow = adata.uns["velot_flow"]
    paths = flow["particles_umap"]
    pt_along = flow["pseudotime_along"]
    cl_along = flow["cluster_along"]

    # len() rather than truthiness: works for lists and array containers.
    if len(paths) == 0:
        raise ValueError("No UMAP-projected flow paths available.")

    coords_key = f"X_{basis}"
    coords = adata.obsm[coords_key]

    # Cluster palette. Built whenever the column exists, independent of
    # show_cells, because the "cluster"/"fate" line colors need it even when
    # the background cells are hidden.
    has_color = color in adata.obs
    if has_color:
        categories = adata.obs[color].astype("category")
        cat_names = categories.cat.categories.tolist()
        codes = categories.cat.codes.values
        colors_array = plt.cm.tab20(np.linspace(0, 1, len(cat_names)))
    else:
        cat_names = []
        codes = None
        colors_array = None

    def _cluster_color(cluster_name):
        """Palette color of a cluster, or gray when it is not in the palette."""
        if cluster_name in cat_names:
            return colors_array[cat_names.index(cluster_name)]
        return "gray"

    fig, ax = plt.subplots(figsize=figsize)

    # Background cells
    if show_cells and has_color:
        for ci, cat_name in enumerate(cat_names):
            mask = codes == ci
            ax.scatter(
                coords[mask, 0],
                coords[mask, 1],
                s=cell_size,
                c=[colors_array[ci]],
                alpha=cell_alpha,
                label=cat_name,
                zorder=1,
            )
    elif show_cells:
        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            s=cell_size,
            c="lightgray",
            alpha=cell_alpha,
            zorder=1,
        )

    # Draw particle trajectories
    for p_idx, path in enumerate(paths):
        if len(path) < 2:
            continue

        if colorby == "pseudotime":
            # Color each segment by local pseudotime
            pts = pt_along[p_idx]
            last_pt = len(pts) - 1
            for step in range(len(path) - 1):
                # Clamp in case the pseudotime record is shorter than the path
                seg_color = plt.cm.coolwarm(pts[min(step, last_pt)])
                ax.plot(
                    path[step:step + 2, 0],
                    path[step:step + 2, 1],
                    color=seg_color,
                    alpha=line_alpha,
                    linewidth=line_width,
                    zorder=2,
                )
        else:
            if colorby == "fate":
                lc = _cluster_color(cl_along[p_idx][-1])
            elif colorby == "cluster":
                lc = _cluster_color(cl_along[p_idx][0])
            else:
                lc = "steelblue"
            ax.plot(
                path[:, 0],
                path[:, 1],
                color=lc,
                alpha=line_alpha,
                linewidth=line_width,
                zorder=2,
            )

        # Start marker
        if show_start:
            ax.scatter(
                path[0, 0],
                path[0, 1],
                s=30,
                c="blue",
                marker="*",
                edgecolors="white",
                linewidths=0.3,
                alpha=0.7,
                zorder=4,
            )

        # End marker
        if show_end:
            ax.scatter(
                path[-1, 0],
                path[-1, 1],
                s=15,
                c="red",
                marker="o",
                edgecolors="white",
                linewidths=0.3,
                alpha=0.7,
                zorder=4,
            )

    # Legend for background
    if show_cells and has_color:
        ax.legend(fontsize=7, loc="best", markerscale=2, framealpha=0.8)

    # Title with summary (metadata is only needed for the default title)
    if title is None:
        meta = flow["metadata"]
        source = flow["source"]
        n = flow["n_particles"]
        title = (
            f"Flow simulation: {n} particles from {source}\n"
            f"Mean final τ = {meta['mean_final_pseudotime']:.2f}, "
            f"mean {meta['mean_trajectory_length']:.0f} steps"
        )
    ax.set_title(title, fontsize=11)

    if not frameon:
        ax.axis("off")

    fig.tight_layout()
    _finish(fig, show=show, save=save, save_path=save_path)
    return fig if not show else None
def metric_summary(
    metrics: dict,
    orientation: str = "horizontal",
    layout: str = "row",
    figsize: Optional[tuple] = None,
    palette_name: str = "Set2",
    median_color: str = "black",
    frameon: bool = True,
    show: bool = True,
    save: bool = False,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Box plots of ICCoh and CBDir metrics from a precomputed summary.

    Parameters
    ----------
    metrics
        Dictionary returned by ``velot.metrics.summary()``. Must contain a
        non-empty ``"iccoh"`` mapping; ``"cbdir"`` is optional. Mapping
        values may be sequences of scores or single scalars. The optional
        keys ``"iccoh_mean"``, ``"iccoh_median"``, ``"cbdir_mean"`` and
        ``"cbdir_median"`` are shown in the panel annotation; a missing mean
        is computed as the mean of per-group means, a missing median is
        shown as NaN.
    orientation
        ``"horizontal"`` or ``"vertical"`` box plots.
    layout
        ``"row"`` or ``"column"``.
    figsize
        Custom figure size.
    palette_name
        Seaborn color palette name. First color for ICCoh,
        second for CBDir.
    median_color
        Median line color.
    frameon
        Show axes frame.
    show
        Display the figure. Forwarded to ``_finish``.
    save, save_path
        Forwarded to ``_finish``.

    Returns
    -------
    matplotlib Figure when ``show`` is False, otherwise None.

    Raises
    ------
    ValueError
        If ``metrics`` has no ``"iccoh"`` entry or it is empty.
    """
    import seaborn as sns

    if "iccoh" not in metrics:
        raise ValueError(
            "metrics dict must contain 'iccoh'. "
            "Pass the output of velot.metrics.summary()."
        )

    iccoh = metrics["iccoh"]
    if len(iccoh) == 0:
        # Otherwise next(iter(...)) below would leak a bare StopIteration.
        raise ValueError("metrics['iccoh'] is empty; nothing to plot.")

    palette = sns.color_palette(palette_name, 8)
    color_iccoh = palette[0]
    color_cbdir = palette[1]

    has_cbdir = "cbdir" in metrics and len(metrics["cbdir"]) > 0
    n_panels = 2 if has_cbdir else 1

    if layout == "row":
        nrows, ncols = 1, n_panels
        default_figsize = (6 * n_panels, 5)
    else:
        nrows, ncols = n_panels, 1
        default_figsize = (7, 5 * n_panels)

    figsize = figsize or default_figsize
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)

    if n_panels == 1:
        # plt.subplots returns a bare Axes for a 1x1 grid
        axes = [axes]
    elif isinstance(axes, np.ndarray):
        axes = axes.flatten()

    horiz = orientation == "horizontal"

    medianprops = dict(color=median_color, linewidth=2)
    whiskerprops = dict(color="black")
    capprops = dict(color="black")
    meanprops = dict(
        marker="o",
        markerfacecolor="white",
        markeredgecolor="black",
        markersize=6,
    )

    def _as_samples(scores, keys):
        """Map each key to a sequence; scalar scores are wrapped in a list."""
        first = next(iter(scores.values()))
        if isinstance(first, (list, np.ndarray)):
            return {k: scores[k] for k in keys}
        return {k: [scores[k]] for k in keys}

    def _summary_mean(key, samples):
        """Precomputed mean from `metrics`, else the mean of group means."""
        if key in metrics:
            return metrics[key]
        return np.mean([np.mean(v) for v in samples.values() if len(v) > 0])

    def _draw_panel(ax, data_dict, labels_str, panel_color, title,
                    mean_val, median_val, xlabel):
        """Draw one box-plot panel with a mean/median annotation."""
        data_list = list(data_dict.values())
        positions = list(range(len(labels_str)))
        boxprops = dict(facecolor=panel_color, edgecolor="black", alpha=0.6)

        ax.boxplot(
            data_list,
            positions=positions,
            vert=not horiz,
            patch_artist=True,
            showmeans=True,
            showfliers=False,
            boxprops=boxprops,
            medianprops=medianprops,
            whiskerprops=whiskerprops,
            capprops=capprops,
            meanprops=meanprops,
            widths=0.6,
        )

        summary_text = f"mean = {mean_val:.2f}\nmedian = {median_val:.2f}"
        if horiz:
            ax.set_yticks(positions)
            ax.set_yticklabels(labels_str, fontsize=12)
            ax.set_xlim(-1.05, 1.05)
            ax.set_xlabel(xlabel, fontsize=16)
            ax.axvline(0, color="gray", linestyle="--", alpha=0.4, linewidth=0.8)
            ax.text(
                0.05, 0.95,
                summary_text,
                transform=ax.transAxes,
                fontsize=14, fontweight="bold",
                va="top", ha="left",
            )
        else:
            ax.set_xticks(positions)
            ax.set_xticklabels(labels_str, fontsize=12, rotation=45, ha="right")
            ax.set_ylim(-1.05, 1.05)
            ax.set_ylabel(xlabel, fontsize=16)
            ax.axhline(0, color="gray", linestyle="--", alpha=0.4, linewidth=0.8)
            ax.text(
                0.05, 0.05,
                summary_text,
                transform=ax.transAxes,
                fontsize=14, fontweight="bold",
                va="bottom", ha="left",
            )

        ax.set_title(title, fontsize=16)
        # Honor the documented `frameon` switch (default True keeps the frame).
        ax.set_frame_on(frameon)

    # ------------------------------------------------------------------
    # Panel 1: ICCoh
    # ------------------------------------------------------------------
    # Keys are sorted by their string form so mixed key types still order.
    iccoh_labels = sorted(iccoh.keys(), key=str)
    iccoh_ordered = _as_samples(iccoh, iccoh_labels)
    iccoh_labels_str = [str(k) for k in iccoh_labels]

    _draw_panel(
        axes[0],
        iccoh_ordered,
        iccoh_labels_str,
        color_iccoh,
        "",
        _summary_mean("iccoh_mean", iccoh_ordered),
        metrics.get("iccoh_median", np.nan),
        "ICCoh score",
    )

    # ------------------------------------------------------------------
    # Panel 2: CBDir
    # ------------------------------------------------------------------
    if has_cbdir:
        cbdir = metrics["cbdir"]
        # CBDir keeps insertion order (no sorting)
        cbdir_labels = list(cbdir.keys())
        cbdir_ordered = _as_samples(cbdir, cbdir_labels)

        # Tuple keys are rendered on two lines, one element per line
        cbdir_labels_str = [
            f"{k[0]}\n {k[1]}" if isinstance(k, tuple) else str(k)
            for k in cbdir_labels
        ]

        _draw_panel(
            axes[1],
            cbdir_ordered,
            cbdir_labels_str,
            color_cbdir,
            "",
            _summary_mean("cbdir_mean", cbdir_ordered),
            metrics.get("cbdir_median", np.nan),
            "CBDir score",
        )

    fig.tight_layout()

    _finish(fig, show=show, save=save, save_path=save_path)
    return fig if not show else None