Source code for velot.benchmark

import os
from pathlib import Path
import json
import time
from datetime import datetime
from typing import Optional, Dict, List, Any

import numpy as np
import pandas as pd
from scipy.stats import mannwhitneyu
from anndata import AnnData

import matplotlib.pyplot as plt


class BenchmarkTimer:
    """Time named pipeline stages, either explicitly or as a context manager.

    Usage::

        timer = BenchmarkTimer()
        with timer("preprocess"):
            ...
        timer.start("fit")
        ...
        timer.stop()
        timer.summary()  # {"preprocess": ..., "fit": ..., "total": ...}

    Durations are measured with ``time.perf_counter``. Re-running a stage name
    overwrites its previously recorded duration.
    """

    def __init__(self):
        # Elapsed seconds of every completed stage, keyed by stage name.
        self.stages: Dict[str, float] = {}
        # Name and start time of the stage currently running (None if idle).
        self._current_stage: Optional[str] = None
        self._start_time: Optional[float] = None
        # Span from the first ``start`` to the most recent ``stop``.
        self._total_start: Optional[float] = None
        self._total_end: Optional[float] = None

    def start(self, stage: str):
        """Begin timing ``stage``.

        If another stage is still running, it is stopped and recorded first,
        so its elapsed time is not silently lost.
        """
        if self._current_stage is not None:
            self.stop()
        self._current_stage = stage
        self._start_time = time.perf_counter()
        if self._total_start is None:
            self._total_start = self._start_time

    def stop(self):
        """Stop the running stage and record its duration; no-op when idle."""
        if self._current_stage is None:
            return
        # One clock read so the stage duration and the total end agree.
        now = time.perf_counter()
        self.stages[self._current_stage] = now - self._start_time
        self._total_end = now
        self._current_stage = None

    def __call__(self, stage: str):
        """Return a context manager that times ``stage`` around a ``with`` body."""
        return _TimerContext(self, stage)

    @property
    def total(self) -> float:
        """Seconds from the first ``start`` to the last ``stop``.

        Returns 0.0 if nothing has been started. While a stage is running, the
        span is measured up to the current time instead.
        """
        if self._total_start is None:
            return 0.0
        if self._current_stage is not None or self._total_end is None:
            end = time.perf_counter()
        else:
            end = self._total_end
        return end - self._total_start

    def summary(self) -> Dict[str, float]:
        """Return a copy of the per-stage durations plus a ``"total"`` entry."""
        result = dict(self.stages)
        result["total"] = self.total
        return result

    def __repr__(self):
        lines = [f" {k}: {v:.2f}s" for k, v in self.stages.items()]
        lines.append(f" total: {self.total:.2f}s")
        return "BenchmarkTimer(\n" + "\n".join(lines) + "\n)"
class _TimerContext:
    """``with``-statement adapter that brackets a block with start/stop calls.

    Entering calls ``timer.start(stage)`` and yields the timer itself; exiting
    (normally or via an exception) calls ``timer.stop()``. Exceptions are not
    suppressed.
    """

    def __init__(self, timer: BenchmarkTimer, stage: str):
        self.timer, self.stage = timer, stage

    def __enter__(self):
        owner = self.timer
        owner.start(self.stage)
        return owner

    def __exit__(self, *exc_info):
        # Implicitly returns None, so any in-flight exception propagates.
        self.timer.stop()
def save_benchmark(
    adata: AnnData,
    results: Dict[str, Any],
    timer: BenchmarkTimer,
    model_name: str,
    dataset_name: str,
    output_dir: str = "benchmark_results",
    extra_info: Optional[Dict] = None,
) -> str:
    """
    Save benchmark results for one (model, dataset) run.

    Writes ``<output_dir>/<model_name>_<dataset_name>.json`` (metrics, timing,
    metadata) and ``<output_dir>/data/<model_name>_<dataset_name>.h5ad``
    (the AnnData object). Existing files with the same names are overwritten.

    Parameters
    ----------
    adata
        AnnData object to store alongside the JSON record (via ``adata.write``).
    results
        Output of ``velot.metrics.summary()``.
    timer
        BenchmarkTimer with timing information.
    model_name
        Name of the model (e.g., ``"scvelo_dynamical"``, ``"velot"``).
    dataset_name
        Name of the dataset (e.g., ``"pancreas"``).
    output_dir
        Directory to save results. Created if needed.
    extra_info
        Any additional metadata. numpy scalars/arrays are converted to plain
        Python types before serialisation.

    Returns
    -------
    Path to the saved JSON file.
    """
    data_dir = os.path.join(output_dir, "data")
    # Creating the nested directory also creates output_dir itself.
    os.makedirs(data_dir, exist_ok=True)

    record = {
        "model": model_name,
        "dataset": dataset_name,
        "timestamp": datetime.now().isoformat(),
        "timing": _to_serializable(timer.summary()),
        # Run extra_info through the same converter as the metrics: numpy
        # values in it would otherwise make json.dump raise TypeError.
        "extra": _to_serializable(extra_info or {}),
        "metrics": _to_serializable(results),
    }

    stem = f"{model_name}_{dataset_name}"
    filepath = os.path.join(output_dir, f"{stem}.json")
    with open(filepath, "w") as f:
        json.dump(record, f, indent=2)
    adata.write(os.path.join(data_dir, f"{stem}.h5ad"))
    print(f" Saved: {filepath}")
    return filepath
def load_benchmarks(
    output_dir: str = "benchmark_results",
    models: Optional[List[str]] = None,
    datasets: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Load benchmark summaries: one row per (model, dataset).

    Reads every ``*.json`` file directly inside ``output_dir`` and extracts
    timing, numeric ``extra`` entries and scalar metrics (e.g. ``*_mean``).
    JSON files that are not benchmark records (not an object with ``"model"``
    and ``"dataset"`` keys) are skipped.

    Parameters
    ----------
    output_dir
        Directory containing JSON result files.
    models
        Keep only these models. None (or empty) for all.
    datasets
        Keep only these datasets. None (or empty) for all.

    Returns
    -------
    DataFrame with columns: model, dataset, timestamp, time_*, numeric extra
    fields and scalar metrics, sorted by dataset then model. Empty (no
    columns) if nothing matches.
    """
    records = []
    for filepath in sorted(Path(output_dir).glob("*.json")):
        with open(filepath) as f:
            data = json.load(f)
        # Other JSON files may live next to the results; ignore them instead
        # of failing on data["model"].
        if not (isinstance(data, dict) and "model" in data and "dataset" in data):
            continue
        if models and data["model"] not in models:
            continue
        if datasets and data["dataset"] not in datasets:
            continue

        row = {
            "model": data["model"],
            "dataset": data["dataset"],
            "timestamp": data.get("timestamp"),
        }
        for stage, seconds in data.get("timing", {}).items():
            row[f"time_{stage}"] = seconds
        # Only numeric values become columns; nested per-group dicts are
        # skipped here. Metrics are applied after extras and win on clashes.
        for section in ("extra", "metrics"):
            for key, value in data.get(section, {}).items():
                if isinstance(value, (int, float)):
                    row[key] = value
        records.append(row)

    df = pd.DataFrame(records)
    if len(df) > 0:
        df = df.sort_values(["dataset", "model"]).reset_index(drop=True)
    return df
def load_benchmarks_per_group(
    output_dir: str = "benchmark_results",
    models: Optional[List[str]] = None,
    datasets: Optional[List[str]] = None,
    metric: Optional[str] = None,
) -> pd.DataFrame:
    """
    Load per-cell metric values for boxplots.

    Parses the nested structure in the JSON where each metric (e.g.,
    ``"cbdir"``, ``"iccoh"``) contains a dict of group → list-of-values.
    Groups can be edges (``"Fev+ → Alpha"``) or clusters (``"Alpha"``).
    Non-numeric and NaN values are dropped. JSON files that are not benchmark
    records (not an object with ``"model"`` and ``"dataset"`` keys) are
    skipped.

    Parameters
    ----------
    output_dir
        Directory containing JSON result files.
    models
        Filter to specific models. None (or empty) for all.
    datasets
        Filter to specific datasets. None (or empty) for all.
    metric
        Specific metric to load (e.g., ``"cbdir"``). If None, loads all
        metrics that contain per-cell lists.

    Returns
    -------
    Long-format DataFrame with columns: model, dataset, metric, group, value.
    Each row is one cell's value for one (model, dataset, metric, group).
    """
    records = []
    for filepath in sorted(Path(output_dir).glob("*.json")):
        with open(filepath) as f:
            data = json.load(f)
        # Ignore unrelated JSON files instead of failing on data["model"].
        if not (isinstance(data, dict) and "model" in data and "dataset" in data):
            continue
        model = data["model"]
        dataset = data["dataset"]
        if models and model not in models:
            continue
        if datasets and dataset not in datasets:
            continue

        for metric_name, metric_val in data.get("metrics", {}).items():
            # Scalars (like iccoh_mean) are not per-cell data.
            if not isinstance(metric_val, dict):
                continue
            if metric is not None and metric_name != metric:
                continue
            # metric_val looks like {"Alpha": [0.6, 0.8, ...], "Beta": [...]}
            for group_name, values in metric_val.items():
                if not isinstance(values, list):
                    continue
                records.extend(
                    {
                        "model": model,
                        "dataset": dataset,
                        "metric": metric_name,
                        "group": group_name,
                        "value": float(v),
                    }
                    for v in values
                    if isinstance(v, (int, float)) and not np.isnan(v)
                )

    df = pd.DataFrame(records)
    if len(df) > 0:
        df = df.sort_values(
            ["dataset", "metric", "model", "group"]
        ).reset_index(drop=True)
    return df
def load_benchmarks_per_group_summary(
    output_dir: str = "benchmark_results",
    models: Optional[List[str]] = None,
    datasets: Optional[List[str]] = None,
    metric: Optional[str] = None,
) -> pd.DataFrame:
    """
    Load one row per (model, dataset, metric, group) with summary stats.

    Useful for bar charts or compact comparisons where per-cell resolution
    is not needed. Non-numeric and NaN values are dropped before the
    statistics are computed, and groups left with no values are omitted, the
    same as in ``load_benchmarks_per_group``. JSON files that are not
    benchmark records (not an object with ``"model"`` and ``"dataset"`` keys)
    are skipped.

    Parameters
    ----------
    output_dir
        Directory containing JSON result files.
    models
        Filter to specific models. None (or empty) for all.
    datasets
        Filter to specific datasets. None (or empty) for all.
    metric
        Only summarise this metric (e.g. ``"cbdir"``). None for all.

    Returns
    -------
    DataFrame with columns: model, dataset, metric, group, mean, median,
    std (population, ddof=0), n
    """
    records = []
    for filepath in sorted(Path(output_dir).glob("*.json")):
        with open(filepath) as f:
            data = json.load(f)
        # Ignore unrelated JSON files instead of failing on data["model"].
        if not (isinstance(data, dict) and "model" in data and "dataset" in data):
            continue
        model = data["model"]
        dataset = data["dataset"]
        if models and model not in models:
            continue
        if datasets and dataset not in datasets:
            continue

        for metric_name, metric_val in data.get("metrics", {}).items():
            if not isinstance(metric_val, dict):
                continue
            if metric is not None and metric_name != metric:
                continue
            for group_name, values in metric_val.items():
                if not isinstance(values, list):
                    continue
                arr = np.asarray(
                    [v for v in values if isinstance(v, (int, float))],
                    dtype=float,
                )
                # Drop NaNs up front: an all-NaN group would otherwise yield
                # n=0 with NaN stats and a "Mean of empty slice" warning.
                arr = arr[~np.isnan(arr)]
                if arr.size == 0:
                    continue
                records.append({
                    "model": model,
                    "dataset": dataset,
                    "metric": metric_name,
                    "group": group_name,
                    "mean": float(np.mean(arr)),
                    "median": float(np.median(arr)),
                    "std": float(np.std(arr)),
                    "n": int(arr.size),
                })

    df = pd.DataFrame(records)
    if len(df) > 0:
        df = df.sort_values(
            ["dataset", "metric", "model", "group"]
        ).reset_index(drop=True)
    return df
def _to_serializable(obj): """Convert numpy types and tuple keys to JSON-compatible types.""" if isinstance(obj, dict): return { _key_to_str(k): _to_serializable(v) for k, v in obj.items() } elif isinstance(obj, (list, tuple)): return [_to_serializable(v) for v in obj] elif isinstance(obj, (np.integer,)): return int(obj) elif isinstance(obj, (np.floating,)): return float(obj) elif isinstance(obj, np.ndarray): return obj.tolist() elif isinstance(obj, (np.bool_,)): return bool(obj) return obj def _key_to_str(key): """Convert dictionary keys to JSON-safe strings.""" if isinstance(key, tuple): return " → ".join(str(k) for k in key) return str(key) def _significance_str(pval): """Convert p-value to significance annotation string.""" if pval is None or np.isnan(pval): return "" if pval > 0.05: return "ns" elif pval > 0.01: return "*" elif pval > 0.001: return "**" elif pval > 0.0001: return "***" else: return "****" def _draw_significance( ax, model_positions, reference_model, fontsize=9, min_samples=5, ): """ Draw significance brackets from reference_model to each other model. Parameters ---------- ax Matplotlib Axes. model_positions Dict of ``{model_name: (x_position, data_array)}``. reference_model Name of the reference model to compare against. fontsize Font size for the significance text. min_samples Minimum number of data points required in both groups to perform the test. Returns ------- Maximum y coordinate used by brackets (for axis limit adjustment). 
""" if reference_model not in model_positions: return None ref_pos, ref_data = model_positions[reference_model] if len(ref_data) < min_samples: return None # Collect all visible data to determine bracket placement all_visible = [] for _, (_, d) in model_positions.items(): if len(d) > 0: q1, q3 = np.percentile(d, [25, 75]) iqr = q3 - q1 lower_whisker = max(np.min(d), q1 - 1.5 * iqr) all_visible.append(lower_whisker) if not all_visible: return None y_bottom = min(all_visible) y_max = ax.get_ylim()[1] y_range = y_max - y_bottom if y_max > y_bottom else 1.0 h = y_range * 0.03 # bracket tick height gap = y_range * 0.06 # vertical gap between stacked brackets margin = y_range * 0.08 # initial margin below data min_y = y_bottom bracket_idx = 0 # Sort other models by position so brackets don't cross others = sorted( [ (model, pos, data) for model, (pos, data) in model_positions.items() if model != reference_model and len(data) >= min_samples ], key=lambda x: abs(x[1] - ref_pos), ) for model, other_pos, other_data in others: # Mann-Whitney U test (two-sided) try: _, pval = mannwhitneyu( ref_data, other_data, alternative="two-sided" ) except ValueError: continue sig_str = _significance_str(pval) # Bracket y position (below the data) y_bar = y_bottom - margin - bracket_idx * gap # Draw bracket: two ticks pointing UP and a horizontal bar left = min(ref_pos, other_pos) right = max(ref_pos, other_pos) ax.plot( [left, left, right, right], [y_bar + h, y_bar, y_bar, y_bar + h], lw=0.8, color="black", clip_on=False, ) # Significance text (below the bar) weight = "normal" if sig_str == "ns" else "bold" ax.text( (left + right) / 2, y_bar - h * 0.3, sig_str, ha="center", va="top", fontsize=fontsize, color="black", fontweight=weight, ) min_y = min(min_y, y_bar - h * 2) bracket_idx += 1 # Expand y-axis to fit brackets if bracket_idx > 0: current_ylim = ax.get_ylim() ax.set_ylim(min_y - y_range * 0.05, current_ylim[1]) return min_y def _plot_aggregated(ax, df_m, all_models, model_colors, 
reference_model=None): """ One boxplot per model, pooling all groups and datasets. """ model_positions = {} for m_idx, model in enumerate(all_models): data = df_m.loc[df_m["model"] == model, "value"].dropna().values if len(data) == 0: continue bp = ax.boxplot( [data], positions=[m_idx], widths=0.6, patch_artist=True, showfliers=False, medianprops=dict(color="black", linewidth=1.5), ) bp["boxes"][0].set_facecolor(model_colors[model]) bp["boxes"][0].set_alpha(0.7) _overlay_points(ax, data, m_idx, model_colors[model]) model_positions[model] = (m_idx, data) ax.set_xticks(range(len(all_models))) ax.set_xticklabels(all_models, rotation=30, ha="right", fontsize=9) # Significance brackets if reference_model is not None: _draw_significance(ax, model_positions, reference_model, fontsize=10) def _plot_per_dataset(ax, df_m, all_models, model_colors, datasets_order=None, reference_model=None): """ Grouped by dataset, one boxplot per model within each group. """ if datasets_order is not None: dataset_list = datasets_order else: dataset_list = df_m["dataset"].unique() n_models = len(all_models) group_width = n_models + 1.5 tick_positions = [] tick_labels = [] # Collect positions per dataset for significance testing dataset_model_positions = {ds: {} for ds in dataset_list} for d_idx, dataset in enumerate(dataset_list): base_pos = d_idx * group_width df_d = df_m[df_m["dataset"] == dataset] for m_idx, model in enumerate(all_models): data = df_d.loc[ df_d["model"] == model, "value" ].dropna().values if len(data) == 0: continue pos = base_pos + m_idx bp = ax.boxplot( [data], positions=[pos], widths=0.6, patch_artist=True, showfliers=False, medianprops=dict(color="black", linewidth=1.5), ) bp["boxes"][0].set_facecolor(model_colors[model]) bp["boxes"][0].set_alpha(0.7) _overlay_points(ax, data, pos, model_colors[model]) dataset_model_positions[dataset][model] = (pos, data) tick_positions.append(base_pos + (n_models - 1) / 2) tick_labels.append(dataset) ax.set_xticks(tick_positions) 
ax.set_xticklabels(tick_labels, fontsize=12) #rotation=30, ha="right", # Vertical separators between datasets for d_idx in range(1, len(dataset_list)): sep_x = d_idx * group_width - group_width / 2 + (n_models - 1) / 2 ax.axvline( sep_x, color="gray", linewidth=0.5, linestyle="--", alpha=0.5 ) # Significance brackets per dataset if reference_model is not None: for dataset in dataset_list: _draw_significance( ax, dataset_model_positions[dataset], reference_model, fontsize=9, ) def _plot_per_group(ax, df_m, all_models, model_colors, reference_model=None): """ Full detail: one boxplot per (model, group), organized by dataset. """ dataset_list = df_m["dataset"].unique() n_models = len(all_models) model_width = 1.0 group_gap = 1.5 dataset_gap = 3.0 current_pos = 0.0 tick_positions = [] tick_labels = [] dataset_centers = [] dataset_boundaries = [] # Collect positions per (dataset, group) for significance group_model_positions = {} for d_idx, dataset in enumerate(dataset_list): df_d = df_m[df_m["dataset"] == dataset] groups = df_d["group"].unique() if d_idx > 0: dataset_boundaries.append(current_pos - dataset_gap / 2) current_pos += dataset_gap dataset_start = current_pos for g_idx, group in enumerate(groups): if g_idx > 0: current_pos += group_gap group_start = current_pos group_key = (dataset, group) group_model_positions[group_key] = {} for m_idx, model in enumerate(all_models): data = df_d.loc[ (df_d["group"] == group) & (df_d["model"] == model), "value", ].dropna().values pos = current_pos if len(data) > 0: bp = ax.boxplot( [data], positions=[pos], widths=0.7, patch_artist=True, showfliers=False, medianprops=dict(color="black", linewidth=1.5), ) bp["boxes"][0].set_facecolor(model_colors[model]) bp["boxes"][0].set_alpha(0.7) group_model_positions[group_key][model] = (pos, data) current_pos += model_width # Group label group_center = (group_start + current_pos - model_width) / 2 label = group if " → " in label: parts = label.split(" → ") label = "→".join(parts) 
tick_positions.append(group_center) tick_labels.append(label) dataset_end = current_pos - model_width dataset_centers.append((dataset_start + dataset_end) / 2) ax.set_xticks(tick_positions) ax.set_xticklabels(tick_labels, rotation=45, ha="right", fontsize=7) # Dataset separators for boundary in dataset_boundaries: ax.axvline( boundary, color="black", linewidth=1.0, linestyle="-", alpha=0.4 ) # Dataset names y_min, y_max = ax.get_ylim() for center, dataset in zip(dataset_centers, dataset_list): ax.text( center, y_min - (y_max - y_min) * 0.15, dataset, ha="center", va="top", fontsize=9, fontweight="bold", transform=ax.transData, ) # Significance brackets per group if reference_model is not None: for group_key, model_pos in group_model_positions.items(): _draw_significance( ax, model_pos, reference_model, fontsize=7, )
def benchmark_comparison(
    output_dir: str = "benchmark_results",
    models: Optional[List[str]] = None,
    models_order: Optional[List[str]] = None,
    datasets: Optional[List[str]] = None,
    datasets_order: Optional[List[str]] = None,
    metrics: Optional[List[str]] = None,
    detail: str = "aggregated",
    reference_model: Optional[str] = "velot",
    show_significance: bool = True,
    show_timing: bool = True,
    figsize_per_panel: tuple = (5, 4),
    show: bool = True,
    save: Optional[str] = None,
) -> plt.Figure:
    """
    Compare benchmark results across models and datasets.

    Parameters
    ----------
    output_dir
        Directory with saved JSON benchmark files.
    models
        Which models to include. None for all.
    models_order
        Order (and selection) of models on the x-axis, in the colours and in
        the legend. None to use the order found in the results.
    datasets
        Which datasets to include. None for all.
    datasets_order
        Order of datasets for ``detail="per_dataset"`` and the timing panel.
        None to use the order found in the results.
    metrics
        Which metrics to plot (e.g., ``["cbdir", "iccoh"]``). None to
        auto-detect all metrics with per-cell data.
    detail
        Level of detail for the metric panels:

        - ``"aggregated"`` (default): one boxplot per model, pooling
          all groups and all datasets.
        - ``"per_dataset"``: one boxplot per (model, dataset), grouped
          by dataset, colored by model.
        - ``"per_group"``: one boxplot per (model, dataset, group),
          every edge/cluster visible, organized by dataset.
    reference_model
        Model to compare others against for significance testing.
        Set to None to disable significance brackets entirely.
    show_significance
        Whether to draw significance brackets. Only applies when
        ``reference_model`` is not None.
    show_timing
        Whether to include a timing comparison panel. The panel is only
        added when the results contain ``time_total``.
    figsize_per_panel
        Size of each subplot.
    show
        Display the plot.
    save
        Path to save.

    Returns
    -------
    matplotlib Figure.

    Raises
    ------
    ValueError
        If ``detail`` is invalid, no results are found, or there is nothing
        to plot.

    Examples
    --------
    Default — pooled, with significance vs velot::

        velot.pl.benchmark_comparison("benchmark_results")

    Without significance::

        velot.pl.benchmark_comparison(
            "benchmark_results", show_significance=False,
        )

    Compare against a different reference::

        velot.pl.benchmark_comparison(
            "benchmark_results", reference_model="scvelo_dynamical",
        )

    Per dataset::

        velot.pl.benchmark_comparison(
            "benchmark_results", detail="per_dataset",
        )

    Full detail::

        velot.pl.benchmark_comparison(
            "benchmark_results", detail="per_group",
            figsize_per_panel=(12, 4),
        )
    """
    if detail not in ("aggregated", "per_dataset", "per_group"):
        raise ValueError(
            f"detail must be 'aggregated', 'per_dataset', or "
            f"'per_group', got '{detail}'."
        )

    df_summary = load_benchmarks(output_dir, models, datasets)
    df_cells = load_benchmarks_per_group(output_dir, models, datasets)

    if len(df_summary) == 0:
        raise ValueError(f"No benchmark results found in '{output_dir}'.")

    # Resolve reference model
    ref = None
    if show_significance and reference_model is not None:
        available_models = df_summary["model"].unique()
        if reference_model in available_models:
            ref = reference_model
        else:
            import warnings
            warnings.warn(
                f"Reference model '{reference_model}' not found in results. "
                f"Available: {list(available_models)}. "
                f"Significance brackets disabled.",
                UserWarning,
            )

    # Auto-detect metrics
    if metrics is None:
        if len(df_cells) > 0:
            metrics = sorted(df_cells["metric"].unique().tolist())
        else:
            metrics = []

    # Only reserve a timing panel when there is timing data to put in it.
    has_timing = show_timing and "time_total" in df_summary.columns
    n_metric_panels = len(metrics)
    n_panels = n_metric_panels + (1 if has_timing else 0)
    if n_panels == 0:
        raise ValueError("No metrics or timing data to plot.")

    fig, axes = plt.subplots(
        1, n_panels,
        figsize=(figsize_per_panel[0] * n_panels, figsize_per_panel[1]),
        squeeze=False,
    )
    axes = axes.flatten()

    if models_order is not None:
        all_models = models_order
    else:
        all_models = df_summary["model"].unique()

    # Set2 entries in a fixed preference order, cycled so that every model
    # gets a colour however many models there are.
    cmap = plt.get_cmap("Set2")
    palette = [cmap(i) for i in (0, 1, 2, 3, 5, 4, 6, 7)]
    model_colors = {
        m: palette[i % len(palette)] for i, m in enumerate(all_models)
    }

    # ── Metric panels ───────────────────────────────────────────
    for i, metric_name in enumerate(metrics):
        ax = axes[i]
        if len(df_cells) > 0:
            df_m = df_cells[df_cells["metric"] == metric_name]
        else:
            # No per-cell data at all: the frame has no "metric" column.
            df_m = df_cells

        if len(df_m) == 0:
            ax.set_title(f"{metric_name}\n(no data)")
            continue

        if detail == "aggregated":
            _plot_aggregated(ax, df_m, all_models, model_colors, ref)
        elif detail == "per_dataset":
            _plot_per_dataset(ax, df_m, all_models, model_colors,
                              datasets_order, ref)
        else:
            _plot_per_group(ax, df_m, all_models, model_colors, ref)

        ax.set_title(metric_name.upper(), fontsize=12, fontweight="bold")
        ax.set_ylabel(metric_name)
        # Fixed [-1.05, 1.05] range, but keep any extra room the
        # significance brackets reserved below it.
        bottom = -1.05
        if ref is not None:
            bottom = min(bottom, ax.get_ylim()[0])
        ax.set_ylim(bottom, 1.05)
        ax.grid(axis="y", alpha=0.3)

    # ── Timing ──────────────────────────────────────────────────
    if has_timing:
        _plot_timing(axes[n_metric_panels], df_summary, all_models,
                     model_colors, datasets_order)

    # ── Legend ──────────────────────────────────────────────────
    handles = [
        plt.Rectangle(
            (0, 0), 1, 1,
            facecolor=model_colors[m], edgecolor="black", alpha=0.7,
        )
        for m in all_models
    ]
    # Mark the reference model in the legend.
    labels = [f"{m} (ref)" if m == ref else m for m in all_models]
    fig.legend(
        handles, labels,
        loc="upper center",
        ncol=max(1, min(len(all_models), 6)),
        fontsize=12,
        frameon=False,
        bbox_to_anchor=(0.5, 1.02),
    )

    fig.tight_layout(rect=[0, 0, 1, 0.95])

    if save:
        os.makedirs(os.path.dirname(save) or ".", exist_ok=True)
        fig.savefig(save, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
    return fig
def benchmark_comparison_individual(
    output_dir: str = "benchmark_results",
    models: Optional[List[str]] = None,
    models_order: Optional[List[str]] = None,
    datasets: Optional[List[str]] = None,
    datasets_order: Optional[List[str]] = None,
    metrics: Optional[List[str]] = None,
    detail: str = "aggregated",
    reference_model: Optional[str] = "velot",
    show_significance: bool = True,
    show_timing: bool = True,
    figsize: tuple = (6, 4),
    ylim_timing: Optional[int] = None,
    save: bool = False,
    save_path: Optional[str] = None,
    save_prefix: str = "benchmark_boxplots",
    save_legend: bool = False,
    show: bool = True,
):
    """
    Draw each benchmark comparison panel as its own figure.

    Produces one figure per metric and, optionally, one timing figure. The
    panels are the same as those drawn by ``benchmark_comparison``, but the
    figures carry no titles and no legend. A legend can be saved separately.

    Parameters
    ----------
    output_dir
        Directory with saved JSON benchmark files.
    models, datasets
        Which models / datasets to include. None for all.
    models_order
        Order (and selection) of models on the x-axis and in the legend.
        None to use the order found in the results.
    datasets_order
        Order of datasets for ``detail="per_dataset"`` and the timing figure.
    metrics
        Which metrics to plot. None to auto-detect all metrics with
        per-cell data.
    detail
        ``"aggregated"``, ``"per_dataset"`` or ``"per_group"``.
    reference_model
        Model to compare others against for significance testing. A warning
        is issued (and brackets are disabled) if it is not in the results.
    show_significance
        Whether to draw significance brackets.
    show_timing
        Whether to draw the timing figure (only if ``time_total`` exists).
    figsize
        Size of each figure.
    ylim_timing
        Upper y-limit of the timing figure. On a linear axis the lower limit
        is set to 0; on a log axis only the upper limit is applied.
    save
        Save each figure as ``<save_path>/<save_prefix>_<letter>.png``, with
        letters ``a``, ``b``, ... in plotting order (metrics, then timing).
    save_path
        Output directory for saved files, created if needed. None for the
        current working directory.
    save_prefix
        File-name prefix for saved figures and the legend.
    save_legend
        Also save a standalone legend as
        ``<save_path>/<save_prefix>_legend.png`` (independent of ``save``).
    show
        Display each figure with ``plt.show()``. Otherwise each figure is
        closed after saving, so figures do not accumulate.

    Raises
    ------
    ValueError
        If ``detail`` is invalid or no benchmark results are found.
    """
    if detail not in ("aggregated", "per_dataset", "per_group"):
        raise ValueError(
            f"detail must be 'aggregated', 'per_dataset', or "
            f"'per_group', got '{detail}'."
        )

    df_summary = load_benchmarks(output_dir, models, datasets)
    df_cells = load_benchmarks_per_group(output_dir, models, datasets)

    if len(df_summary) == 0:
        raise ValueError(f"No benchmark results found in '{output_dir}'.")

    # Resolve reference model
    ref = None
    if show_significance and reference_model is not None:
        available_models = df_summary["model"].unique()
        if reference_model in available_models:
            ref = reference_model
        else:
            import warnings
            warnings.warn(
                f"Reference model '{reference_model}' not found in results. "
                f"Available: {list(available_models)}. "
                f"Significance brackets disabled.",
                UserWarning,
            )

    # Auto-detect metrics
    if metrics is None:
        if len(df_cells) > 0:
            metrics = sorted(df_cells["metric"].unique().tolist())
        else:
            metrics = []

    if models_order is not None:
        all_models = models_order
    else:
        all_models = df_summary["model"].unique()

    # Set2 entries in a fixed preference order, cycled so that every model
    # gets a colour however many models there are.
    cmap = plt.get_cmap("Set2")
    palette = [cmap(i) for i in (0, 1, 2, 3, 5, 4, 6, 7)]
    model_colors = {
        m: palette[i % len(palette)] for i, m in enumerate(all_models)
    }

    # Default to the working directory (previously this became "None/...").
    save_dir = save_path if save_path is not None else "."
    if save or save_legend:
        os.makedirs(save_dir, exist_ok=True)

    def _finish(fig, idx):
        """Lay out, optionally save as letter ``idx``, then show or close."""
        fig.tight_layout()
        if save:
            fig.savefig(
                os.path.join(save_dir,
                             f"{save_prefix}_{chr(ord('a') + idx)}.png"),
                dpi=300, bbox_inches="tight",
            )
        if show:
            plt.show()
        else:
            plt.close(fig)

    label_idx = 0

    # ── Metric panels ───────────────────────────────────────────
    for metric_name in metrics:
        fig, ax = plt.subplots(figsize=figsize)
        if len(df_cells) > 0:
            df_m = df_cells[df_cells["metric"] == metric_name]
        else:
            # No per-cell data at all: the frame has no "metric" column.
            df_m = df_cells

        if len(df_m) == 0:
            ax.text(0.5, 0.5, f"{metric_name}\n(no data)",
                    ha="center", va="center", transform=ax.transAxes)
        elif detail == "aggregated":
            _plot_aggregated(ax, df_m, all_models, model_colors, ref)
        elif detail == "per_dataset":
            _plot_per_dataset(ax, df_m, all_models, model_colors,
                              datasets_order, ref)
        else:
            _plot_per_group(ax, df_m, all_models, model_colors, ref)

        ax.set_ylabel(metric_name, fontsize=12)
        # Fixed [-1.05, 1.05] range, but keep any extra room the
        # significance brackets reserved below it.
        bottom = -1.05
        if ref is not None:
            bottom = min(bottom, ax.get_ylim()[0])
        ax.set_ylim(bottom, 1.05)
        ax.grid(axis="y", alpha=0.3)

        _finish(fig, label_idx)
        label_idx += 1

    # ── Timing panel ────────────────────────────────────────────
    if show_timing and "time_total" in df_summary.columns:
        fig, ax = plt.subplots(figsize=figsize)
        _plot_timing(ax, df_summary, all_models, model_colors, datasets_order)
        if ylim_timing is not None:
            if ax.get_yscale() == "log":
                # A zero lower limit is invalid on a log axis.
                ax.set_ylim(top=ylim_timing)
            else:
                ax.set_ylim(0, ylim_timing)
        _finish(fig, label_idx)

    # ── Standalone legend ───────────────────────────────────────
    if save_legend:
        fig_leg = plt.figure(figsize=(0.1, 0.1))
        handles = [
            plt.Rectangle(
                (0, 0), 1, 1,
                facecolor=model_colors[m], edgecolor="black", alpha=0.7,
            )
            for m in all_models
        ]
        labels = [f"{m} (ref)" if m == ref else m for m in all_models]
        leg = fig_leg.legend(
            handles, labels,
            loc="center",
            ncol=max(1, len(all_models)),
            fontsize=12,
            frameon=False,
        )
        # Render once so the legend extent is known, then crop to it.
        fig_leg.canvas.draw()
        bbox = leg.get_window_extent().transformed(
            fig_leg.dpi_scale_trans.inverted()
        )
        fig_leg.savefig(
            os.path.join(save_dir, f"{save_prefix}_legend.png"),
            dpi=300, bbox_inches=bbox,
        )
        plt.close(fig_leg)
        print(f"Saved {save_prefix}_legend.png")
def _plot_timing(ax, df_summary, all_models, model_colors, datasets_order=None):
    """
    Draw total execution time as a grouped bar chart on ``ax``.

    Bars are grouped by dataset (``datasets_order`` if given, otherwise the
    order of first appearance in ``df_summary["dataset"]``), with one bar
    per model in ``all_models``. A (model, dataset) pair with no row in
    ``df_summary`` is drawn with height 0. Only positive times get a text
    label: seconds ("12.3s") below 60 s, minutes ("2.1m") from 60 s upward.
    The y-axis switches to log scale when the largest ``time_total`` in
    ``df_summary`` exceeds ten times the smallest (floored at 0.1).
    """
    datasets_shown = (
        df_summary["dataset"].unique() if datasets_order is None else datasets_order
    )
    positions = np.arange(len(datasets_shown))
    model_count = len(all_models)
    width = 0.8 / max(model_count, 1)

    def _lookup_time(model, dataset):
        # First matching run wins; absent pairs are plotted as 0.
        hits = df_summary.loc[
            (df_summary["model"] == model) & (df_summary["dataset"] == dataset),
            "time_total",
        ].values
        return hits[0] if len(hits) > 0 else 0

    for slot, model in enumerate(all_models):
        model_times = [_lookup_time(model, ds) for ds in datasets_shown]
        # Centre the group of bars on each dataset tick.
        shift = (slot - model_count / 2 + 0.5) * width
        rects = ax.bar(
            positions + shift,
            model_times,
            width * 0.9,
            label=model,
            color=model_colors[model],
            edgecolor="black",
            linewidth=0.5,
            alpha=0.9,
        )
        for rect, seconds in zip(rects, model_times):
            # Written as "not > 0" so NaN times are skipped as well.
            if not seconds > 0:
                continue
            caption = f"{seconds / 60:.1f}m" if seconds >= 60 else f"{seconds:.1f}s"
            ax.text(
                rect.get_x() + rect.get_width() / 2,
                rect.get_height() + 0.1,
                caption,
                ha="center",
                va="bottom",
                fontsize=12,
                rotation=0,
            )

    ax.set_xticks(positions)
    ax.set_xticklabels(datasets_shown, fontsize=12)
    ax.set_ylabel("Time (seconds)")

    # Use a log axis when the time range spans more than one order of magnitude.
    observed = df_summary["time_total"].dropna()
    if len(observed) > 1:
        spread = observed.max() / max(observed.min(), 0.1)
        if spread > 10:
            ax.set_yscale("log")
    ax.grid(axis="y", alpha=0.3)


def _overlay_points(ax, data, position, color, max_points=200):
    """
    Scatter jittered copies of ``data`` at x = ``position`` on ``ax``.

    Intended to sit on top of a boxplot. When ``data`` holds more than
    ``max_points`` values, a random subset of ``max_points`` values is drawn
    without replacement. Both the subsampling and the horizontal jitter
    (normal, sd 0.08) use seed 42, so the picture is reproducible.
    """
    if len(data) > max_points:
        points = np.random.RandomState(42).choice(data, max_points, replace=False)
    else:
        points = data
    offsets = np.random.RandomState(42).normal(0, 0.08, len(points))
    xs = np.full(len(points), position) + offsets
    ax.scatter(
        xs,
        points,
        s=8,
        c=[color],
        edgecolors="black",
        linewidths=0.2,
        zorder=5,
        alpha=0.5,
    )
def benchmark_summary_table(
    output_dir: str = "benchmark_results",
    models: Optional[List[str]] = None,
    datasets: Optional[List[str]] = None,
    metrics: Optional[List[str]] = None,
    reference_model: str = "velot",
    detail: str = "aggregated",
    save_csv: Optional[str] = None,
) -> pd.DataFrame:
    """
    Build a summary DataFrame with statistics and p-values.

    Parameters
    ----------
    output_dir
        Directory containing JSON result files.
    models
        Filter to specific models. None for all.
    datasets
        Filter to specific datasets. None for all.
    metrics
        Which metrics to include. None for all.
    reference_model
        Model to compare others against. P-values (two-sided Mann-Whitney U)
        are computed between this model and each other model, for matching
        groups that have at least 5 values on both sides.
    detail
        Aggregation level:

        - ``"aggregated"``: one row per (model, metric), pooling all
          datasets and groups.
        - ``"per_dataset"``: one row per (model, dataset, metric),
          pooling groups within each dataset.
        - ``"per_group"``: one row per (model, dataset, group, metric),
          no aggregation.
    save_csv
        Path to save the DataFrame as CSV. None to skip saving. Floats are
        written with 6 decimals, except p-values, which are written in
        scientific notation so that very small p-values are not rounded
        to zero.

    Returns
    -------
    DataFrame with columns including mean, median, std, n, and p-value
    vs the reference model.

    Raises
    ------
    ValueError
        If ``detail`` is invalid, no results are found in ``output_dir``,
        or none of the requested ``metrics`` are present.

    Examples
    --------
    Quick overview::

        df = velot.benchmark.benchmark_summary_table("benchmark_results")
        print(df)

    Per dataset with CSV export::

        df = velot.benchmark.benchmark_summary_table(
            "benchmark_results",
            detail="per_dataset",
            save_csv="benchmark_results/summary_per_dataset.csv",
        )

    Full detail::

        df = velot.benchmark.benchmark_summary_table(
            "benchmark_results",
            detail="per_group",
        )
    """
    if detail not in ("aggregated", "per_dataset", "per_group"):
        raise ValueError(
            f"detail must be 'aggregated', 'per_dataset', or "
            f"'per_group', got '{detail}'."
        )

    df_cells = load_benchmarks_per_group(output_dir, models, datasets)
    df_summary = load_benchmarks(output_dir, models, datasets)

    if len(df_cells) == 0:
        raise ValueError(f"No benchmark results found in '{output_dir}'.")

    if metrics is not None:
        df_cells = df_cells[df_cells["metric"].isin(metrics)]
        # Without this guard an empty selection surfaces later as an
        # opaque KeyError('model') on the empty result frame.
        if len(df_cells) == 0:
            raise ValueError(
                f"None of the requested metrics {list(metrics)} were found "
                f"in '{output_dir}'."
            )

    # Timing info per (model, dataset) for inclusion in the table.
    timing_map = {}
    if len(df_summary) > 0 and "time_total" in df_summary.columns:
        for _, run in df_summary.iterrows():
            timing_map[(run["model"], run["dataset"])] = {
                "time_total": run.get("time_total", np.nan),
                "time_velocity": run.get("time_velocity", np.nan),
            }

    # ── Grouping keys ───────────────────────────────────────────
    group_keys = {
        "aggregated": ["model", "metric"],
        "per_dataset": ["model", "dataset", "metric"],
        "per_group": ["model", "dataset", "metric", "group"],
    }[detail]

    # ── Summary stats per group ─────────────────────────────────
    records = []
    for keys, grp in df_cells.groupby(group_keys):
        if not isinstance(keys, tuple):
            keys = (keys,)
        record = dict(zip(group_keys, keys))
        values = grp["value"].dropna().values
        has_values = len(values) > 0
        # NaN values were dropped above, so plain mean/median/std are enough.
        # The emptiness check avoids "Mean of empty slice" warnings.
        record["mean"] = float(np.mean(values)) if has_values else np.nan
        record["median"] = float(np.median(values)) if has_values else np.nan
        record["std"] = float(np.std(values)) if has_values else np.nan
        record["n"] = int(len(values))
        # Raw values are kept temporarily for the p-value computation.
        record["_values"] = values
        records.append(record)

    df = pd.DataFrame(records)

    # ── P-values vs reference ───────────────────────────────────
    df["pvalue"] = np.nan
    df["significance"] = ""

    if reference_model is not None and reference_model in df["model"].values:
        # Rows are matched on every grouping key except the model itself.
        match_keys = [k for k in group_keys if k != "model"]
        ref_mask = df["model"] == reference_model
        ref_lookup = {
            tuple(ref_row[k] for k in match_keys): ref_row["_values"]
            for _, ref_row in df[ref_mask].iterrows()
        }

        for idx, other_row in df[~ref_mask].iterrows():
            ref_data = ref_lookup.get(tuple(other_row[k] for k in match_keys))
            other_data = other_row["_values"]
            if ref_data is None or len(ref_data) < 5 or len(other_data) < 5:
                continue
            try:
                _, pval = mannwhitneyu(
                    ref_data, other_data, alternative="two-sided"
                )
            except ValueError:
                continue
            else:
                df.at[idx, "pvalue"] = pval
                df.at[idx, "significance"] = _significance_str(pval)

        df.loc[ref_mask, "significance"] = "ref"

    # ── Clean up ────────────────────────────────────────────────
    df = df.drop(columns=["_values"])

    # Per-run timing for the per_dataset / per_group views.
    if detail in ("per_dataset", "per_group") and timing_map:
        df["time_total"] = df.apply(
            lambda r: timing_map.get(
                (r["model"], r.get("dataset", "")), {}
            ).get("time_total", np.nan),
            axis=1,
        )

    # Mean timing across datasets for the aggregated view. Series.mean
    # skips NaN and returns NaN (without a warning) when all are NaN.
    if detail == "aggregated" and timing_map:
        model_times = {}
        for (model, _), t in timing_map.items():
            model_times.setdefault(model, []).append(t["time_total"])
        df["time_total_mean"] = df["model"].map(
            {m: pd.Series(ts, dtype=float).mean() for m, ts in model_times.items()}
        )

    # Sort and reorder columns for readability.
    front_cols = [k for k in group_keys if k in df.columns]
    df = df.sort_values(front_cols).reset_index(drop=True)

    stat_cols = ["mean", "median", "std", "n", "pvalue", "significance"]
    time_cols = [c for c in df.columns if c.startswith("time_")]
    other_cols = [
        c for c in df.columns if c not in front_cols + stat_cols + time_cols
    ]
    col_order = front_cols + stat_cols + time_cols + other_cols
    df = df[[c for c in col_order if c in df.columns]]

    # ── Save ────────────────────────────────────────────────────
    if save_csv is not None:
        os.makedirs(os.path.dirname(save_csv) or ".", exist_ok=True)
        df_out = df.copy()
        # "%.6f" would print every p-value below 5e-7 as 0.000000; use
        # scientific notation for that column. Missing p-values stay empty,
        # matching to_csv's default na_rep.
        df_out["pvalue"] = [
            "" if pd.isna(p) else f"{p:.3e}" for p in df_out["pvalue"]
        ]
        df_out.to_csv(save_csv, index=False, float_format="%.6f")
        print(f" Saved summary table: {save_csv}")

    return df
def benchmark_heatmaps(
    output_dir: str = "benchmark_results",
    models: Optional[List[str]] = None,
    datasets: Optional[List[str]] = None,
    metrics: Optional[List[str]] = None,
    stat: str = "mean",
    annot_fmt: str = ".3f",
    cmap: str = "viridis",
    figsize_per_panel: tuple = (5, 3),
    show: bool = True,
    save: Optional[str] = None,
) -> plt.Figure:
    """
    Heatmap of model × dataset for each metric.

    One panel is drawn per metric. When the summary results contain a
    ``time_total`` column, an extra execution-time panel is appended.

    Parameters
    ----------
    output_dir
        Directory with saved JSON benchmark files.
    models
        Which models to include. None for all.
    datasets
        Which datasets to include. None for all.
    metrics
        Which metrics to plot. None for all.
    stat
        Statistic to display in each cell:

        - ``"mean"``: mean of per-cell values.
        - ``"median"``: median of per-cell values.
        - ``"rank"``: average rank across groups within each dataset.
          Rank 1 = best (highest value).
    annot_fmt
        Format string for cell annotations (not used for rank or time
        panels, which use ``.2f`` and ``.1f``).
    cmap
        Colormap name (or Colormap object). For rank and time panels the
        reversed colormap is used, so that the best (lowest) values sit at
        the same end of the colormap as the best (highest) values of
        mean/median panels.
    figsize_per_panel
        Size per metric panel.
    show
        Display the plot.
    save
        Path to save.

    Returns
    -------
    matplotlib Figure.

    Raises
    ------
    ValueError
        If ``stat`` is invalid, no results are found, or there are no
        metrics to plot.

    Examples
    --------
    Mean performance heatmap::

        velot.pl.benchmark_heatmaps("benchmark_results", stat="mean")

    Rank heatmap::

        velot.pl.benchmark_heatmaps("benchmark_results", stat="rank")

    Specific metrics::

        velot.pl.benchmark_heatmaps(
            "benchmark_results",
            metrics=["cbdir"],
            stat="median",
        )
    """
    if stat not in ("mean", "median", "rank"):
        raise ValueError(
            f"stat must be 'mean', 'median', or 'rank', got '{stat}'."
        )

    df_cells = load_benchmarks_per_group(output_dir, models, datasets)

    if len(df_cells) == 0:
        raise ValueError(f"No benchmark results found in '{output_dir}'.")

    if metrics is None:
        metrics = sorted(df_cells["metric"].unique().tolist())

    n_metrics = len(metrics)
    if n_metrics == 0:
        raise ValueError("No metrics to plot.")

    # Timing is shown as an extra panel when available.
    df_summary = load_benchmarks(output_dir, models, datasets)
    has_timing = "time_total" in df_summary.columns
    n_panels = n_metrics + (1 if has_timing else 0)

    # Reverse via the Colormap API instead of appending "_r": that also
    # works for names already ending in "_r" and for Colormap objects.
    cmap_reversed = plt.get_cmap(cmap).reversed()

    fig, axes = plt.subplots(
        1, n_panels,
        figsize=(figsize_per_panel[0] * n_panels, figsize_per_panel[1]),
        squeeze=False,
    )
    axes = axes.flatten()

    for i, metric_name in enumerate(metrics):
        ax = axes[i]
        df_m = df_cells[df_cells["metric"] == metric_name]

        if len(df_m) == 0:
            ax.set_title(f"{metric_name}\n(no data)")
            continue

        if stat == "rank":
            pivot = _compute_rank_pivot(df_m)
            _draw_heatmap(
                ax, pivot,
                title=f"{metric_name.upper()} (avg rank)",
                cmap=cmap_reversed,  # low rank = best
                annot_fmt=".2f",
                vmin=1,
                vmax=len(pivot.index),
                lower_is_better=True,
            )
        else:
            pivot = _compute_stat_pivot(df_m, stat)
            _draw_heatmap(
                ax, pivot,
                title=f"{metric_name.upper()} ({stat})",
                cmap=cmap,
                annot_fmt=annot_fmt,
                lower_is_better=False,
            )

    if has_timing:
        ax = axes[n_metrics]
        pivot_time = df_summary.pivot_table(
            values="time_total", index="model", columns="dataset",
            aggfunc="first",
        )
        if len(pivot_time) > 0:
            _draw_heatmap(
                ax, pivot_time,
                title="Time (seconds)",
                cmap=cmap_reversed,  # lower time = best
                annot_fmt=".1f",
                lower_is_better=True,
            )

    plt.tight_layout()

    if save:
        os.makedirs(os.path.dirname(save) or ".", exist_ok=True)
        fig.savefig(save, dpi=150, bbox_inches="tight")
    if show:
        plt.show()

    return fig
def benchmark_ranking(
    output_dir: str = "benchmark_results",
    models: Optional[List[str]] = None,
    models_order: Optional[List[str]] = None,
    datasets: Optional[List[str]] = None,
    metrics: Optional[List[str]] = None,
    include_timing: bool = True,
    figsize: tuple = (8, 5),
    show: bool = True,
    save: Optional[str] = None,
) -> plt.Figure:
    """
    Overall ranking summary: average rank across all datasets and groups.

    For each (dataset, group, metric) combination, models are ranked 1 to
    N (1 = best = highest mean value). Ranks are then averaged to produce
    a single score per model. Lower is better.

    Optionally includes execution time as a ranked criterion, ranked per
    dataset (1 = fastest).

    Parameters
    ----------
    output_dir
        Directory with saved JSON benchmark files.
    models
        Which models to include. None for all.
    models_order
        Order in which colours are assigned to models. Models listed here
        get colours first, and any remaining models get the following ones.
        This does not change plot order; models are always sorted by
        overall rank.
    datasets
        Which datasets to include. None for all.
    metrics
        Which metrics to include. None for all.
    include_timing
        Whether to include execution time as a ranking criterion.
    figsize
        Figure size.
    show
        Display the plot.
    save
        Path to save.

    Returns
    -------
    matplotlib Figure.

    Raises
    ------
    ValueError
        If no results are found or no ranks could be computed.

    Examples
    --------
    ::

        velot.pl.benchmark_ranking("benchmark_results")
    """
    df_cells = load_benchmarks_per_group(output_dir, models, datasets)
    df_summary = load_benchmarks(output_dir, models, datasets)

    if len(df_cells) == 0:
        raise ValueError(f"No benchmark results found in '{output_dir}'.")

    if metrics is None:
        metrics = sorted(df_cells["metric"].unique().tolist())

    # ── Ranks per (dataset, group, metric) ──────────────────────
    rank_records = []
    for metric_name in metrics:
        df_m = df_cells[df_cells["metric"] == metric_name]
        for (dataset, group), grp in df_m.groupby(["dataset", "group"]):
            model_means = grp.groupby("model")["value"].mean()
            # Higher value = better, so rank 1 goes to the largest mean.
            ranks = model_means.rank(ascending=False, method="average")
            for model, rank in ranks.items():
                rank_records.append({
                    "model": model,
                    "dataset": dataset,
                    "group": group,
                    "metric": metric_name,
                    "criterion": metric_name,
                    "rank": rank,
                })

    # ── Timing ranks per dataset ────────────────────────────────
    if include_timing and "time_total" in df_summary.columns:
        for dataset in df_summary["dataset"].unique():
            df_d = df_summary[df_summary["dataset"] == dataset]
            times = df_d.set_index("model")["time_total"]
            # Lower time = better, so rank 1 goes to the fastest model.
            ranks = times.rank(ascending=True, method="average")
            for model, rank in ranks.items():
                rank_records.append({
                    "model": model,
                    "dataset": dataset,
                    "group": "_timing_",
                    "metric": "time",
                    "criterion": "time",
                    "rank": rank,
                })

    df_ranks = pd.DataFrame(rank_records)
    if len(df_ranks) == 0:
        raise ValueError("No ranking data computed.")

    # ── Aggregate ───────────────────────────────────────────────
    avg_per_criterion = (
        df_ranks.groupby(["model", "criterion"])["rank"]
        .mean()
        .reset_index()
        .rename(columns={"rank": "avg_rank"})
    )
    avg_overall = (
        df_ranks.groupby("model")["rank"]
        .mean()
        .reset_index()
        .rename(columns={"rank": "avg_rank"})
        .sort_values("avg_rank")
    )

    # ── Plot ────────────────────────────────────────────────────
    all_models = avg_overall["model"].values
    n_models = len(all_models)
    criteria = sorted(avg_per_criterion["criterion"].unique())
    n_criteria = len(criteria)

    fig, axes = plt.subplots(
        1, 2, figsize=figsize,
        gridspec_kw={"width_ratios": [2, 3]},
    )

    # Colours: Set2 entries 0, 1, 2, 3, 5 first, so figures with up to
    # five models keep their existing colours, then the remaining Set2
    # entries. The palette cycles if there are still more models. Any model
    # missing from models_order still gets a colour, which avoids a KeyError.
    set2 = plt.get_cmap("Set2")
    preferred_idx = [0, 1, 2, 3, 5]
    palette = [set2(i) for i in preferred_idx]
    palette += [set2(i) for i in range(set2.N) if i not in preferred_idx]
    color_order = list(models_order) if models_order is not None else []
    color_order += [m for m in all_models if m not in color_order]
    model_colors = {
        m: palette[i % len(palette)] for i, m in enumerate(color_order)
    }

    # --- Left panel: overall average rank bar chart ---
    ax = axes[0]
    bars = ax.barh(
        range(n_models),
        avg_overall["avg_rank"].values,
        color=[model_colors[m] for m in all_models],
        edgecolor="black",
        linewidth=0.5,
    )
    ax.set_yticks(range(n_models))
    ax.set_yticklabels(all_models, fontsize=10)
    ax.set_xlabel("Average rank (lower = better)", fontsize=10)
    ax.set_title("Overall ranking", fontsize=12, fontweight="bold")
    ax.invert_yaxis()  # best (lowest rank) at top

    for bar, val in zip(bars, avg_overall["avg_rank"].values):
        ax.text(
            bar.get_width() + 0.05,
            bar.get_y() + bar.get_height() / 2,
            f"{val:.2f}",
            ha="left", va="center", fontsize=9, fontweight="bold",
        )

    ax.set_xlim(0, n_models + 0.5)
    ax.axvline(1, color="green", linewidth=0.8, linestyle="--", alpha=0.5,
               label="Best possible (1.0)")
    ax.grid(axis="x", alpha=0.3)

    # --- Right panel: rank breakdown by criterion ---
    ax = axes[1]
    # pivot_table drops criteria whose ranks are all NaN (e.g. missing
    # timings). Reindexing keeps rows and columns aligned with all_models
    # and criteria, so every bar series has exactly n_criteria heights.
    pivot_rank = avg_per_criterion.pivot_table(
        values="avg_rank", index="model", columns="criterion",
    ).reindex(index=all_models, columns=criteria)

    x = np.arange(n_criteria)
    bar_width = 0.8 / max(n_models, 1)

    for m_idx, model in enumerate(all_models):
        values = pivot_rank.loc[model].values
        offset = (m_idx - n_models / 2 + 0.5) * bar_width
        ax.bar(
            x + offset, values, bar_width * 0.9,
            color=model_colors[model],
            edgecolor="black", linewidth=0.5,
            label=model,
        )

    ax.set_xticks(x)
    ax.set_xticklabels(
        [c.upper() for c in criteria],
        rotation=30, ha="right", fontsize=9,
    )
    ax.set_ylabel("Average rank (lower = better)", fontsize=10)
    ax.set_title("Rank by criterion", fontsize=12, fontweight="bold")
    ax.axhline(1, color="green", linewidth=0.8, linestyle="--", alpha=0.5)
    ax.grid(axis="y", alpha=0.3)
    ax.legend(fontsize=8, frameon=False, loc="upper left")

    plt.tight_layout()

    if save:
        os.makedirs(os.path.dirname(save) or ".", exist_ok=True)
        fig.savefig(save, dpi=150, bbox_inches="tight")
    if show:
        plt.show()

    return fig
# ── Internal helpers ────────────────────────────────────────────


def _compute_stat_pivot(df_m, stat="mean"):
    """
    Pivot to model × dataset, pooling all groups of each dataset.

    Any ``stat`` other than ``"mean"`` is treated as ``"median"``.
    """
    agg_func = "mean" if stat == "mean" else "median"
    pivot = df_m.pivot_table(
        values="value", index="model", columns="dataset",
        aggfunc=agg_func,
    )
    return pivot


def _compute_rank_pivot(df_m):
    """
    Pivot to model × dataset, holding the average rank across groups.

    For each (dataset, group), models are ranked by their mean value
    (1 = highest = best, ties share the average rank). The ranks are then
    averaged across the groups within each dataset.
    """
    rank_records = []
    for (dataset, group), grp in df_m.groupby(["dataset", "group"]):
        model_means = grp.groupby("model")["value"].mean()
        ranks = model_means.rank(ascending=False, method="average")
        for model, rank in ranks.items():
            rank_records.append({
                "model": model,
                "dataset": dataset,
                "rank": rank,
            })

    df_ranks = pd.DataFrame(rank_records)
    pivot = df_ranks.pivot_table(
        values="rank", index="model", columns="dataset",
        aggfunc="mean",
    )
    return pivot


def _draw_heatmap(
    ax,
    pivot,
    title="",
    cmap="viridis",
    annot_fmt=".3f",
    vmin=None,
    vmax=None,
    lower_is_better=False,
):
    """
    Draw an annotated heatmap of ``pivot`` on ``ax``.

    Each cell is annotated with its value (``annot_fmt``). Missing values
    are shown as an em dash. In each column (dataset), the best value is
    written in bold: the minimum if ``lower_is_better``, otherwise the
    maximum. A colorbar is attached to ``ax``.

    Parameters
    ----------
    ax
        Matplotlib axes to draw on.
    pivot
        DataFrame of numbers (rows = models, columns = datasets).
    title
        Axes title.
    cmap
        Colormap name or object passed to ``imshow``.
    annot_fmt
        Format spec for cell annotations.
    vmin, vmax
        Colour limits. They default to the finite min / max of ``pivot``,
        or 0 / 1 when it has no finite values.
    lower_is_better
        Whether the smallest value in a column is the best one.
    """
    data = np.asarray(pivot.values, dtype=float)
    n_rows, n_cols = data.shape

    # Guard against all-NaN / empty pivots, which would give NaN limits.
    finite = data[np.isfinite(data)]
    if vmin is None:
        vmin = finite.min() if finite.size else 0.0
    if vmax is None:
        vmax = finite.max() if finite.size else 1.0

    im = ax.imshow(
        data, cmap=cmap, aspect="auto",
        vmin=vmin, vmax=vmax,
    )

    # Best row per column, computed before annotating.
    best_rows = {}
    for c in range(n_cols):
        col_data = data[:, c]
        if np.isnan(col_data).all():
            continue
        best = np.nanargmin(col_data) if lower_is_better else np.nanargmax(col_data)
        best_rows[c] = int(best)

    for r in range(n_rows):
        for c in range(n_cols):
            val = data[r, c]
            if np.isnan(val):
                text, color, weight = "—", "gray", "normal"
            else:
                text = f"{val:{annot_fmt}}"
                # Choose the text colour from the luminance of the colour
                # actually rendered for this cell, so it contrasts with any
                # colormap (including reversed ones).
                red, green, blue, _ = im.cmap(im.norm(val))
                luminance = 0.2126 * red + 0.7152 * green + 0.0722 * blue
                color = "black" if luminance > 0.5 else "white"
                weight = "bold" if best_rows.get(c) == r else "normal"
            ax.text(
                c, r, text,
                ha="center", va="center",
                fontsize=10, fontweight=weight, color=color,
            )

    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(
        pivot.columns, rotation=30, ha="right", fontsize=9,
    )
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(pivot.index, fontsize=10)
    ax.set_title(title, fontsize=12, fontweight="bold", pad=10)

    cbar = ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=8)
def benchmark_dotplot(
    output_dir: str = "benchmark_results",
    models: Optional[List[str]] = None,
    datasets: Optional[List[str]] = None,
    datasets_order: Optional[List[str]] = None,
    metrics: Optional[List[str]] = None,
    include_timing: bool = True,
    normalize_mode: str = "global",
    stat: str = "mean",
    metric_colors: Optional[Dict[str, str]] = None,
    metric_labels: Optional[Dict[str, str]] = None,
    model_labels: Optional[Dict[str, str]] = None,
    higher_is_better: Optional[Dict[str, bool]] = None,
    figsize: Optional[tuple] = None,
    legend: bool = True,
    summary: bool = True,
    show: bool = True,
    save: Optional[str] = None,
) -> plt.Figure:
    """
    Integrated dotplot benchmark summary.

    Each metric gets one column per dataset plus an "Overall" column (the
    mean across datasets). Circle area reflects a normalised score (see
    ``normalize_mode``) and is for visual comparison only. Displayed
    values match the heatmap values for the chosen ``stat``. A right-hand
    panel ranks models by a grand score:

    - for ``stat="rank"``, the average of the per-metric overall ranks
      (lower is better);
    - otherwise, the average of per-metric overall values min-max scaled
      across models (time is log-transformed first and inverted). Higher
      is better.

    Parameters
    ----------
    output_dir
        Directory with saved JSON benchmark files.
    models
        Which models to include. None for all.
    datasets
        Which datasets to include. None for all.
    datasets_order
        Column order of datasets. None sorts the dataset names found in the
        results.
    metrics
        Which metrics to include. None for all found. Requested metrics
        that are not present in the results are dropped.
    include_timing
        Whether to include execution time as a metric group.
    normalize_mode
        How circle sizes are scaled to [0, 1] within each metric:

        - ``"global"``: over the whole model × dataset matrix.
        - ``"column"``: separately for each dataset.
        - ``"row"``: separately for each model.
    stat
        Aggregation statistic — same as in ``benchmark_heatmaps``:

        - ``"mean"``: mean of per-cell values per (model, dataset).
        - ``"median"``: median of per-cell values.
        - ``"rank"``: average rank across groups within each dataset
          (1 = best). Consistent with heatmap ``stat="rank"``.
    metric_colors
        Dict mapping metric name → color. Missing entries get defaults.
        The dict passed in is not modified.
    metric_labels
        Dict mapping metric name → display label. Not modified.
    model_labels
        Dict mapping model name → display label. Not modified.
    higher_is_better
        Dict mapping metric name → bool. Only used for normalisation
        direction and ``stat="rank"``. Default True for accuracy metrics,
        False for time. Not modified.
    figsize
        Figure size. Auto-computed if None.
    legend
        Draw the circle-size and metric-colour legends below the plot.
    summary
        Print a text table of overall values and grand scores.
    show
        Display the plot.
    save
        Path to save.

    Returns
    -------
    matplotlib Figure.

    Raises
    ------
    ValueError
        If ``stat`` or ``normalize_mode`` is invalid, no results are found,
        or none of the requested metrics exist.
    """
    # Validate up front: an unknown stat previously fell through to the
    # median, and an unknown normalize_mode silently produced NaN dot sizes.
    if stat not in ("mean", "median", "rank"):
        raise ValueError(
            f"stat must be 'mean', 'median', or 'rank', got '{stat}'."
        )
    if normalize_mode not in ("global", "column", "row"):
        raise ValueError(
            f"normalize_mode must be 'global', 'column', or 'row', "
            f"got '{normalize_mode}'."
        )

    df_cells = load_benchmarks_per_group(output_dir, models, datasets)
    df_summary = load_benchmarks(output_dir, models, datasets)

    if len(df_cells) == 0:
        raise ValueError(f"No benchmark results found in '{output_dir}'.")

    # ── Resolve metrics ─────────────────────────────────────────
    available_metrics = sorted(df_cells["metric"].unique().tolist())
    if metrics is None:
        metrics = available_metrics
    else:
        metrics = [m for m in metrics if m in available_metrics]
    if len(metrics) == 0:
        raise ValueError("No metrics found to plot.")

    all_models = sorted(df_summary["model"].unique().tolist())
    if datasets_order is not None:
        all_datasets = datasets_order
    else:
        all_datasets = sorted(df_summary["dataset"].unique().tolist())
    n_models = len(all_models)
    n_datasets = len(all_datasets)

    # ── Defaults (copies, so caller-supplied dicts are never mutated) ──
    default_palette = ["#c0392b", "#2874a6", "#8e44ad", "#d35400",
                       "#16a085", "#7d3c98", "#2c3e50"]
    metric_colors = dict(metric_colors) if metric_colors is not None else {}
    for i, m in enumerate(metrics):
        if m not in metric_colors:
            metric_colors[m] = default_palette[i % len(default_palette)]
    if include_timing:
        metric_colors.setdefault("time", "#1e8449")

    higher_is_better = (
        dict(higher_is_better) if higher_is_better is not None else {}
    )
    for m in metrics:
        higher_is_better.setdefault(m, True)
    higher_is_better.setdefault("time", False)

    metric_labels = dict(metric_labels) if metric_labels is not None else {}
    for m in metrics:
        arrow = "↑" if higher_is_better.get(m, True) else "↓"
        metric_labels.setdefault(m, f"{m.upper()} {arrow}")
    metric_labels.setdefault("time", "Time ↓ (seconds)")

    model_labels = dict(model_labels) if model_labels is not None else {}
    for m in all_models:
        model_labels.setdefault(m, m)

    # ── RAW data matrices (same values as the heatmaps) ─────────
    raw_matrices = {}  # metric → (n_models, n_datasets)
    for metric_name in metrics:
        df_m = df_cells[df_cells["metric"] == metric_name]
        mat = np.full((n_models, n_datasets), np.nan)

        if stat == "rank":
            # Rank per (dataset, group), then average within the dataset
            # (same logic as _compute_rank_pivot, but direction-aware).
            hb = higher_is_better.get(metric_name, True)
            for di, dataset in enumerate(all_datasets):
                df_d = df_m[df_m["dataset"] == dataset]
                model_ranks_all = {m: [] for m in all_models}
                for group in df_d["group"].unique():
                    df_g = df_d[df_d["group"] == group]
                    model_means = df_g.groupby("model")["value"].mean()
                    ranks = model_means.rank(
                        ascending=not hb, method="average"
                    )
                    for model in all_models:
                        if model in ranks.index:
                            model_ranks_all[model].append(ranks[model])
                for mi, model in enumerate(all_models):
                    if model_ranks_all[model]:
                        mat[mi, di] = np.mean(model_ranks_all[model])
        else:
            agg_func = np.nanmean if stat == "mean" else np.nanmedian
            for mi, model in enumerate(all_models):
                for di, dataset in enumerate(all_datasets):
                    mask = (
                        (df_m["model"] == model)
                        & (df_m["dataset"] == dataset)
                    )
                    vals = df_m.loc[mask, "value"].dropna().values
                    if len(vals) > 0:
                        mat[mi, di] = agg_func(vals)

        raw_matrices[metric_name] = mat

    # Time matrix: per-dataset ranks (stat="rank") or first recorded time.
    has_timing = include_timing and "time_total" in df_summary.columns
    if has_timing:
        time_mat = np.full((n_models, n_datasets), np.nan)
        for di, dataset in enumerate(all_datasets):
            df_d = df_summary[df_summary["dataset"] == dataset]
            if stat == "rank":
                ranks = df_d.set_index("model")["time_total"].rank(
                    ascending=True, method="average"
                )
                for mi, model in enumerate(all_models):
                    if model in ranks.index:
                        time_mat[mi, di] = ranks[model]
            else:
                for mi, model in enumerate(all_models):
                    vals = df_d.loc[df_d["model"] == model, "time_total"].values
                    if len(vals) > 0:
                        time_mat[mi, di] = vals[0]
        raw_matrices["time"] = time_mat
        all_metric_keys = metrics + ["time"]
    else:
        all_metric_keys = list(metrics)

    # ── OVERALL per metric: mean across datasets ────────────────
    raw_overall = {
        mk: np.nanmean(raw_matrices[mk], axis=1) for mk in all_metric_keys
    }

    # ── Grand score (used to order the models) ──────────────────
    if stat == "rank":
        # Average of average ranks — lower is better.
        grand = np.nanmean(
            np.column_stack([raw_overall[mk] for mk in all_metric_keys]),
            axis=1,
        )
        order = np.argsort(grand)
        grand_higher_better = False
    else:
        # Metrics have different scales, so min-max each metric's overall
        # values across models before averaging.
        norm_overalls = []
        for mk in all_metric_keys:
            ov = raw_overall[mk].astype(float)
            valid = ~np.isnan(ov)
            if not valid.any():
                norm_overalls.append(np.full_like(ov, np.nan))
                continue
            hb = higher_is_better.get(mk, True)
            if mk == "time":
                # Times span orders of magnitude: compare on a log scale.
                ov[valid] = np.log(ov[valid])
                hb = False
            vmin, vmax = ov[valid].min(), ov[valid].max()
            if vmax == vmin:
                norm_overalls.append(np.where(valid, 1.0, np.nan))
                continue
            normed = np.full_like(ov, np.nan)
            normed[valid] = (ov[valid] - vmin) / (vmax - vmin)
            if not hb:
                normed[valid] = 1.0 - normed[valid]
            norm_overalls.append(normed)
        grand = np.nanmean(np.column_stack(norm_overalls), axis=1)
        order = np.argsort(-grand)
        grand_higher_better = True

    models_sorted = [all_models[i] for i in order]

    # ── Normalise for DOT SIZING only ───────────────────────────
    def _norm_for_sizing(arr, hb=True, log_transform=False, mode="global"):
        """
        Map values to [0, 1] for circle area (1 = best).

        ``mode`` selects the scaling scope: "global" (whole matrix),
        "column" (per dataset) or "row" (per model). A constant scope
        maps to 1.0; NaNs stay NaN.
        """
        work = np.asarray(arr, dtype=float).copy()
        if log_transform:
            finite = ~np.isnan(work)
            work[finite] = np.log(work[finite])
        out = np.full(work.shape, np.nan)

        def _scale(values):
            # values: 1-D, NaN-free.
            lo, hi = values.min(), values.max()
            if hi == lo:
                return np.ones_like(values)
            scaled = (values - lo) / (hi - lo)
            return scaled if hb else 1.0 - scaled

        if mode == "global":
            valid = ~np.isnan(work)
            if valid.any():
                out[valid] = _scale(work[valid])
        elif mode == "column":
            for j in range(work.shape[1]):
                valid = ~np.isnan(work[:, j])
                if valid.any():
                    out[valid, j] = _scale(work[valid, j])
        else:  # "row"
            for i in range(work.shape[0]):
                valid = ~np.isnan(work[i, :])
                if valid.any():
                    out[i, valid] = _scale(work[i, valid])
        return out

    size_matrices = {}
    for mk in all_metric_keys:
        if stat == "rank":
            # Ranks: lower is better regardless of the metric direction.
            size_matrices[mk] = _norm_for_sizing(
                raw_matrices[mk], hb=False, log_transform=False,
                mode=normalize_mode,
            )
        elif mk == "time":
            size_matrices[mk] = _norm_for_sizing(
                raw_matrices[mk], hb=False, log_transform=True,
                mode=normalize_mode,
            )
        else:
            size_matrices[mk] = _norm_for_sizing(
                raw_matrices[mk], hb=higher_is_better.get(mk, True),
                log_transform=False, mode=normalize_mode,
            )

    size_overall = {
        mk: np.nanmean(size_matrices[mk], axis=1) for mk in all_metric_keys
    }

    # ── Column layout ───────────────────────────────────────────
    annot_fmt = ".2f"
    time_annot_fmt = ".2f" if stat == "rank" else ".1f"

    cols = []
    group_spans = []
    x = 0
    for mk in all_metric_keys:
        gcolor = metric_colors[mk]
        fmt = time_annot_fmt if mk == "time" else annot_fmt
        g_start = x
        for di, dataset in enumerate(all_datasets):
            cols.append({
                "x": x, "metric": mk, "color": gcolor,
                "label": dataset, "kind": "dataset",
                "raw_values": raw_matrices[mk][:, di],
                "size_values": size_matrices[mk][:, di],
                "fmt": fmt,
            })
            x += 1
        cols.append({
            "x": x, "metric": mk, "color": gcolor,
            "label": "Overall", "kind": "overall",
            "raw_values": raw_overall[mk],
            "size_values": size_overall[mk],
            "fmt": fmt,
        })
        x += 1
        group_spans.append(
            (metric_labels[mk], gcolor, g_start - 0.45, x - 1 + 0.45)
        )
        x += 0.7  # gap between metric groups

    m2i = {m: i for i, m in enumerate(all_models)}

    # ── Figure setup ────────────────────────────────────────────
    if figsize is None:
        width = max(12, len(cols) * 0.9 + 6)
        height = max(4, n_models * 0.9 + 2.5)
        figsize = (width, height)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_axes([0.14, 0.16, 0.68, 0.64])

    # ── Zebra striping ──────────────────────────────────────────
    for row in range(n_models):
        if row % 2 == 0:
            ax.add_patch(plt.Rectangle(
                (-0.7, row - 0.48), x + 8, 0.96,
                color="#f6f6f6", zorder=0, lw=0,
            ))

    # ── Best value per column (metric stored per column, not inferred
    #    from its colour, so shared colours cannot confuse it) ──
    best_per_col = {}
    for ci, col_info in enumerate(cols):
        raw_vals = col_info["raw_values"]
        if np.all(np.isnan(raw_vals)):
            continue
        mk = col_info["metric"]
        if stat == "rank" or mk == "time":
            hb = False  # lower rank / lower time = better
        else:
            hb = higher_is_better.get(mk, True)
        best_per_col[ci] = np.nanmax(raw_vals) if hb else np.nanmin(raw_vals)

    # ── Dots ────────────────────────────────────────────────────
    for row, m in enumerate(models_sorted):
        mi = m2i[m]
        for ci, col_info in enumerate(cols):
            cx = col_info["x"]
            raw_val = col_info["raw_values"][mi]
            size_val = col_info["size_values"][mi]

            if np.isnan(raw_val):
                ax.plot(
                    cx, row, marker="x", color="#888",
                    markersize=11, mew=2, zorder=3,
                )
                continue

            is_best = (
                ci in best_per_col
                and np.isclose(raw_val, best_per_col[ci], rtol=1e-6)
            )
            dot_size = 60 + 1100 * max(size_val, 0)
            ax.scatter(
                cx, row, s=dot_size,
                color=col_info["color"],
                alpha=1 if is_best else 0.75,
                edgecolor="black",
                linewidth=1.0,
                zorder=4 if is_best else 3,
            )

            # Only annotate circles large enough to hold the text.
            if size_val >= 0.30:
                ax.text(
                    cx, row, f"{raw_val:{col_info['fmt']}}",
                    ha="center", va="center",
                    fontsize=7.5, color="white",
                    fontweight="bold", zorder=5,
                )

    # ── Group header bars ───────────────────────────────────────
    header_y = -1.05
    for gtitle, gcolor, gs, ge in group_spans:
        ax.add_patch(plt.Rectangle(
            (gs, header_y - 0.32), ge - gs, 0.5,
            color=gcolor, alpha=0.95, zorder=2,
        ))
        ax.text(
            (gs + ge) / 2, header_y - 0.07, gtitle,
            ha="center", va="center", color="white",
            fontsize=10.5, fontweight="bold",
        )

    # ── Right-hand ranking panel ────────────────────────────────
    bar_x0 = x
    bar_w = 3.0
    rank_x = bar_x0 + bar_w + 1.1

    grand_valid = grand[~np.isnan(grand)]
    if len(grand_valid) > 0:
        g_min, g_max = grand_valid.min(), grand_valid.max()
    else:
        g_min, g_max = 0, 1
    g_span = max(g_max - g_min, 1e-10)

    for row, m in enumerate(models_sorted):
        val = grand[m2i[m]]
        bcolor = "#1b2631" if row == 0 else "#566573"
        rel = (val - g_min) / g_span
        # For ranks lower is better, so invert to keep "longer = better".
        bar_frac = rel if grand_higher_better else 1.0 - rel
        bar_frac = max(bar_frac, 0.02)  # minimum visible bar

        ax.add_patch(plt.Rectangle(
            (bar_x0, row - 0.28), bar_frac * bar_w, 0.56,
            color=bcolor, alpha=0.92, zorder=3,
        ))
        ax.text(
            bar_x0 + bar_frac * bar_w + 0.08, row, f"{val:.2f}",
            va="center", fontsize=10, fontweight="bold",
        )

    header_w = (rank_x - bar_x0) + 0.7
    ax.add_patch(plt.Rectangle(
        (bar_x0, header_y - 0.32), header_w, 0.5,
        color="#1b2631", alpha=0.95, zorder=2,
    ))
    ax.text(
        bar_x0 + header_w / 2, header_y - 0.07,
        "Overall ranking",
        ha="center", va="center", color="white",
        fontsize=10.5, fontweight="bold",
    )

    # ── Axes formatting ─────────────────────────────────────────
    ax.set_xticks([c["x"] for c in cols])
    ax.set_xticklabels(
        [
            c["label"].capitalize() if c["kind"] != "overall" else "Overall"
            for c in cols
        ],
        rotation=40, fontsize=9.5, ha="right",
    )
    ax.set_yticks(range(n_models))
    ax.set_yticklabels(
        [model_labels.get(m, m) for m in models_sorted],
        fontsize=11, fontweight="bold",
    )
    ax.invert_yaxis()
    ax.set_xlim(-0.9, rank_x + 1.0)
    ax.set_ylim(n_models - 0.5, header_y - 0.55)
    for sp in ("top", "right", "left", "bottom"):
        ax.spines[sp].set_visible(False)
    ax.tick_params(left=False, bottom=False)

    if legend:
        # ── Legend: circle size scale ───────────────────────────
        leg = fig.add_axes([0.14, -0.05, 0.38, 0.1])
        leg.axis("off")
        leg.set_xlim(0, 5.5)
        leg.set_ylim(0, 1)
        leg.text(
            -0.05, 0.6, "Norm. score:",
            ha="right", va="center", fontsize=9.5, fontweight="bold",
        )
        for i, s in enumerate([0.1, 0.3, 0.5, 0.7, 0.9, 1.0]):
            leg.scatter(
                i * 0.85 + 0.25, 0.6, s=60 + 1100 * s,
                color="#7f8c8d", alpha=0.78,
                edgecolor="black", linewidth=0.5,
            )
            leg.text(
                i * 0.85 + 0.25, 0, f"{s:.1f}",
                ha="center", fontsize=8.5,
            )

        # ── Legend: metric colours ──────────────────────────────
        cl = fig.add_axes([0.55, -0.05, 0.35, 0.1])
        cl.axis("off")
        n_legend = len(all_metric_keys)
        cl.set_xlim(0, max(n_legend * 1.3, 2))
        cl.set_ylim(0, 1)
        for i, mk in enumerate(all_metric_keys):
            cl.scatter(
                i * 1.25 + 0.2, 0.6, s=380,
                color=metric_colors[mk], alpha=0.88,
                edgecolor="black", linewidth=0.5,
            )
            cl.text(
                i * 1.25 + 0.45, 0.6,
                metric_labels.get(mk, mk.upper()),
                va="center", fontsize=9.5, fontweight="bold",
            )

    if summary:
        # ── Print numeric summary ───────────────────────────────
        print(f"\nBenchmark ranking summary (stat={stat}):")
        print("-" * 70)
        header_parts = ["Rank", f"{'Model':<24}"]
        header_parts += [f"{mk.upper():>8}" for mk in all_metric_keys]
        header_parts.append(f"{'GRAND':>8}")
        print(" ".join(header_parts))
        print("-" * 70)
        for rk, m in enumerate(models_sorted, 1):
            mi = m2i[m]
            parts = [f" #{rk}", f" {model_labels.get(m, m):<24}"]
            for mk in all_metric_keys:
                ov = raw_overall[mk][mi]
                if mk == "time" and stat != "rank":
                    parts.append(f"{ov:>8.1f}")
                else:
                    parts.append(f"{ov:>8.3f}")
            parts.append(f"{grand[mi]:>8.3f}")
            print(" ".join(parts))

    # ── Save / show ─────────────────────────────────────────────
    if save:
        os.makedirs(os.path.dirname(save) or ".", exist_ok=True)
        fig.savefig(save, dpi=300, bbox_inches="tight", facecolor="white")
        print(f"\n Saved: {save}")
    if show:
        plt.show()

    return fig