Skip to content

visualization

orchard.evaluation.visualization

Visualization utilities for model evaluation.

Provides formatted visual reports including training loss/accuracy curves, normalized confusion matrices, and sample prediction grids. Integrated with the PlotContext DTO for aesthetic and technical consistency.

show_predictions(model, loader, device, classes, save_path=None, ctx=None, n=None)

Visualize model predictions on a sample batch.

Coordinates data extraction, model inference, grid layout generation, and image post-processing. Highlights correct (green) vs. incorrect (red) predictions.

Parameters:

Name Type Description Default
model Module

Trained model to evaluate.

required
loader DataLoader[Any]

DataLoader providing evaluation samples.

required
device device

Target device for inference.

required
classes list[str]

Human-readable class label names.

required
save_path Path | None

Output file path. If None, displays interactively.

None
ctx PlotContext | None

PlotContext with layout and normalization settings.

None
n int | None

Number of samples to display. Defaults to ctx.n_samples.

None
Source code in orchard/evaluation/visualization.py
def show_predictions(
    model: nn.Module,
    loader: DataLoader[Any],
    device: torch.device,
    classes: list[str],
    save_path: Path | None = None,
    ctx: PlotContext | None = None,
    n: int | None = None,
) -> None:
    """
    Visualize model predictions on a sample batch.

    Coordinates data extraction, model inference, grid layout generation,
    and image post-processing. Highlights correct (green) vs. incorrect
    (red) predictions.

    Args:
        model: Trained model to evaluate.
        loader: DataLoader providing evaluation samples.
        device: Target device for inference.
        classes: Human-readable class label names.
        save_path: Output file path. If None, displays interactively.
        ctx: PlotContext with layout and normalization settings.
        n: Number of samples to display. Defaults to ``ctx.n_samples``.
    """
    model.eval()
    # cosmetic fallback
    style = ctx.plot_style if ctx else "seaborn-v0_8-muted"  # pragma: no mutate

    with plt.style.context(style):  # pragma: no mutate
        # 1. Parameter Resolution & Batch Inference
        # FIX: the previous `n or (...)` treated any falsy `n` (e.g. an
        # explicit n=0) as "not provided" and substituted the default.
        # Only fall back when the caller actually omitted `n`.
        if n is not None:
            num_samples = n
        else:
            # cosmetic fallback
            num_samples = ctx.n_samples if ctx else 12  # pragma: no mutate
        images, labels, preds = _get_predictions_batch(
            model, loader, device, num_samples  # pragma: no mutate
        )

        # 2. Grid & Figure Setup
        # cosmetic fallback
        grid_cols = ctx.grid_cols if ctx else 4  # pragma: no mutate
        _, axes = _setup_prediction_grid(len(images), grid_cols, ctx)  # pragma: no mutate

        # 3. Plotting Loop
        for i, ax in enumerate(axes):
            # guard for extra grid cells beyond actual images
            if i < len(images):  # pragma: no mutate
                _plot_single_prediction(
                    ax, images[i], labels[i], preds[i], classes, ctx  # pragma: no mutate
                )
            ax.axis("off")  # pragma: no mutate

        # 4. Suptitle
        if ctx:
            plt.suptitle(_build_suptitle(ctx), fontsize=14)  # pragma: no mutate

        # 5. Export and Cleanup
        # forwarding; tested in _finalize_figure
        _finalize_figure(plt, save_path, ctx)  # pragma: no mutate

plot_training_curves(train_losses, val_accuracies, out_path, ctx)

Plot training loss and validation accuracy on a dual-axis chart.

Saves the figure to disk and exports raw numerical data as .npz for reproducibility.

Parameters:

Name Type Description Default
train_losses Sequence[float]

Per-epoch training loss values.

required
val_accuracies Sequence[float]

Per-epoch validation accuracy values.

required
out_path Path

Destination file path for the saved figure.

required
ctx PlotContext

PlotContext with architecture and evaluation settings.

required
Source code in orchard/evaluation/visualization.py
def plot_training_curves(
    train_losses: Sequence[float], val_accuracies: Sequence[float], out_path: Path, ctx: PlotContext
) -> None:
    """
    Plot training loss and validation accuracy on a dual-axis chart.

    Saves the figure to disk and exports raw numerical data as ``.npz``
    for reproducibility.

    Args:
        train_losses: Per-epoch training loss values.
        val_accuracies: Per-epoch validation accuracy values.
        out_path: Destination file path for the saved figure.
        ctx: PlotContext with architecture and evaluation settings.
    """
    # matplotlib cosmetic — colors, fonts, sizes, layout
    with plt.style.context(ctx.plot_style):  # pragma: no mutate
        fig, ax1 = plt.subplots(figsize=(9, 6))  # pragma: no mutate

        # Left Axis: Training Loss
        ax1.plot(train_losses, color="#e74c3c", lw=2, label="Training Loss")  # pragma: no mutate
        ax1.set_xlabel("Epoch")  # pragma: no mutate
        ax1.set_ylabel("Loss", color="#e74c3c", fontweight="bold")  # pragma: no mutate
        ax1.tick_params(axis="y", labelcolor="#e74c3c")  # pragma: no mutate
        ax1.grid(True, linestyle="--", alpha=0.4)  # pragma: no mutate

        # Right Axis: Validation Accuracy
        ax2 = ax1.twinx()  # pragma: no mutate
        ax2.plot(  # pragma: no mutate
            val_accuracies, color="#3498db", lw=2, label="Validation Accuracy"  # pragma: no mutate
        )  # pragma: no mutate
        ax2.set_ylabel("Accuracy", color="#3498db", fontweight="bold")  # pragma: no mutate
        ax2.tick_params(axis="y", labelcolor="#3498db")  # pragma: no mutate

        fig.suptitle(  # pragma: no mutate
            f"Training Metrics — {ctx.arch_name} | Resolution — {ctx.resolution}",
            fontsize=14,  # pragma: no mutate
            y=1.02,  # pragma: no mutate
        )

        fig.tight_layout()  # pragma: no mutate

        # FIX: save/close the explicit figure handle rather than pyplot's
        # implicit "current figure" — the two can diverge if any other code
        # creates figures between the subplots() call and the save. This
        # also matches the fig.savefig(...) style used by
        # plot_confusion_matrix in this module.
        fig.savefig(out_path, dpi=ctx.fig_dpi, bbox_inches="tight")  # pragma: no mutate
        logger.info(
            "%s%s %-18s: %s", LogStyle.INDENT, LogStyle.ARROW, "Training Curves", out_path.name
        )

        # Export raw data for post-run analysis
        npz_path = out_path.with_suffix(".npz")
        np.savez(npz_path, train_losses=train_losses, val_accuracies=val_accuracies)
        plt.close(fig)

plot_confusion_matrix(all_labels, all_preds, classes, out_path, ctx)

Generate and save a row-normalized confusion matrix plot.

Parameters:

Name Type Description Default
all_labels NDArray[Any]

Ground-truth label array.

required
all_preds NDArray[Any]

Predicted label array.

required
classes list[str]

Human-readable class label names.

required
out_path Path

Destination file path for the saved figure.

required
ctx PlotContext

PlotContext with architecture and evaluation settings.

required
Source code in orchard/evaluation/visualization.py
def plot_confusion_matrix(
    all_labels: npt.NDArray[Any],
    all_preds: npt.NDArray[Any],
    classes: list[str],
    out_path: Path,
    ctx: PlotContext,
) -> None:
    """
    Generate and save a row-normalized confusion matrix plot.

    Args:
        all_labels: Ground-truth label array.
        all_preds: Predicted label array.
        classes: Human-readable class label names.
        out_path: Destination file path for the saved figure.
        ctx: PlotContext with architecture and evaluation settings.
    """
    # matplotlib cosmetic — confusion matrix rendering and styling
    with plt.style.context(ctx.plot_style):  # pragma: no mutate
        class_ids = np.arange(len(classes))
        # Row-normalized matrix; classes absent from the labels produce
        # NaN rows, which are zeroed out before rendering.
        norm_cm = np.nan_to_num(
            confusion_matrix(  # pragma: no mutate
                all_labels,
                all_preds,
                labels=class_ids,
                normalize="true",  # pragma: no mutate
            )
        )

        fig, ax = plt.subplots(figsize=(11, 9))  # pragma: no mutate
        display = ConfusionMatrixDisplay(
            confusion_matrix=norm_cm, display_labels=classes  # pragma: no mutate
        )
        display.plot(  # pragma: no mutate
            ax=ax,
            cmap=ctx.cmap_confusion,
            xticks_rotation=45,  # pragma: no mutate
            values_format=".3f",  # pragma: no mutate
        )  # pragma: no mutate

        plt.title(  # pragma: no mutate
            f"Confusion Matrix — {ctx.arch_name} | Resolution — {ctx.resolution}",
            fontsize=12,  # pragma: no mutate
            pad=20,  # pragma: no mutate
        )
        plt.tight_layout()  # pragma: no mutate

        fig.savefig(out_path, dpi=ctx.fig_dpi, bbox_inches="tight")  # pragma: no mutate
        plt.close()
        logger.info(
            "%s%s %-18s: %s", LogStyle.INDENT, LogStyle.ARROW, "Confusion Matrix", out_path.name
        )