Skip to content

visualization

orchard.evaluation.visualization

Visualization utilities for model evaluation.

Provides formatted visual reports including training loss/accuracy curves, normalized confusion matrices, and sample prediction grids. Integrated with the PlotContext DTO for aesthetic and technical consistency.

show_predictions(model, loader, device, classes, save_path=None, ctx=None, n=None)

Visualize model predictions on a sample batch.

Coordinates data extraction, model inference, grid layout generation, and image post-processing. Highlights correct (green) vs. incorrect (red) predictions.

Parameters:

Name Type Description Default
model Module

Trained model to evaluate.

required
loader DataLoader[Any]

DataLoader providing evaluation samples.

required
device device

Target device for inference.

required
classes list[str]

Human-readable class label names.

required
save_path Path | None

Output file path. If None, displays interactively.

None
ctx PlotContext | None

PlotContext with layout and normalization settings.

None
n int | None

Number of samples to display. Defaults to ctx.n_samples.

None
Source code in orchard/evaluation/visualization.py
def show_predictions(
    model: nn.Module,
    loader: DataLoader[Any],
    device: torch.device,
    classes: list[str],
    save_path: Path | None = None,
    ctx: PlotContext | None = None,
    n: int | None = None,
) -> None:
    """
    Visualize model predictions on a sample batch.

    Coordinates data extraction, model inference, grid layout generation,
    and image post-processing. Highlights correct (green) vs. incorrect
    (red) predictions.

    Args:
        model: Trained model to evaluate.
        loader: DataLoader providing evaluation samples.
        device: Target device for inference.
        classes: Human-readable class label names.
        save_path: Output file path. If None, displays interactively.
        ctx: PlotContext with layout and normalization settings.
        n: Number of samples to display. Defaults to ``ctx.n_samples``.
    """
    model.eval()
    style = ctx.plot_style if ctx else _DEFAULT_STYLE

    with plt.style.context(style):
        # 1. Parameter Resolution & Batch Inference
        # Explicit None check: `n or ...` would silently discard an
        # explicit `n=0` (0 is falsy) and fall back to the default.
        if n is not None:
            num_samples = n
        else:
            num_samples = ctx.n_samples if ctx else _DEFAULT_NUM_SAMPLES
        images, labels, preds = _get_predictions_batch(model, loader, device, num_samples)

        # 2. Grid & Figure Setup
        grid_cols = ctx.grid_cols if ctx else _DEFAULT_GRID_COLS
        _, axes = _setup_prediction_grid(len(images), grid_cols, ctx)

        # 3. Plotting Loop
        for i, ax in enumerate(axes):
            # guard for extra grid cells beyond actual images
            if i < len(images):
                _plot_single_prediction(ax, images[i], labels[i], preds[i], classes, ctx)
            ax.axis("off")

        # 4. Suptitle (only when a context supplies the metadata to build it)
        if ctx:
            plt.suptitle(_build_suptitle(ctx), fontsize=_SUPTITLE_FONTSIZE)

        # 5. Export and Cleanup (save to disk or show interactively)
        _finalize_figure(plt, save_path, ctx)

plot_training_curves(train_losses, val_metric_values, out_path, ctx, *, val_label)

Plot training loss and a validation metric on a dual-axis chart.

Saves the figure to disk and exports raw numerical data as .npz for reproducibility.

Parameters:

Name Type Description Default
train_losses Sequence[float]

Per-epoch training loss values.

required
val_metric_values Sequence[float]

Per-epoch validation metric values.

required
out_path Path

Destination file path for the saved figure.

required
ctx PlotContext

PlotContext with architecture and evaluation settings.

required
val_label str

Label for the right y-axis and legend entry.

required
Source code in orchard/evaluation/visualization.py
def plot_training_curves(
    train_losses: Sequence[float],
    val_metric_values: Sequence[float],
    out_path: Path,
    ctx: PlotContext,
    *,
    val_label: str,
) -> None:
    """
    Plot training loss and a validation metric on a dual-axis chart.

    Saves the figure to disk and exports raw numerical data as ``.npz``
    for reproducibility.

    Args:
        train_losses: Per-epoch training loss values.
        val_metric_values: Per-epoch validation metric values.
        out_path: Destination file path for the saved figure.
        ctx: PlotContext with architecture and evaluation settings.
        val_label: Label for the right y-axis and legend entry.
    """
    with plt.style.context(ctx.plot_style):
        fig, ax1 = plt.subplots(figsize=_TRAINING_FIGSIZE)

        # Left Axis: Training Loss
        ax1.plot(train_losses, color=_LOSS_COLOR, lw=_LINE_WIDTH, label="Training Loss")
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Loss", color=_LOSS_COLOR, fontweight="bold")
        ax1.tick_params(axis="y", labelcolor=_LOSS_COLOR)
        ax1.grid(True, linestyle=_GRID_LINESTYLE, alpha=_GRID_ALPHA)

        # Right Axis: Validation Metric (shares the x-axis with the loss)
        ax2 = ax1.twinx()
        ax2.plot(val_metric_values, color=_METRIC_COLOR, lw=_LINE_WIDTH, label=val_label)
        ax2.set_ylabel(val_label, color=_METRIC_COLOR, fontweight="bold")
        ax2.tick_params(axis="y", labelcolor=_METRIC_COLOR)

        fig.suptitle(
            f"Training Metrics — {ctx.arch_name} | Resolution — {ctx.resolution}",
            fontsize=_SUPTITLE_FONTSIZE,
            y=_TRAINING_TITLE_Y,
        )

        fig.tight_layout()

        # Save via the explicit Figure handle (consistent with
        # plot_confusion_matrix) rather than the implicit current figure,
        # which could be changed by intervening pyplot calls.
        fig.savefig(out_path, dpi=ctx.fig_dpi, bbox_inches=_SAVEFIG_BBOX)
        logger.info(
            "%s%s %-18s: %s", LogStyle.INDENT, LogStyle.ARROW, "Training Curves", out_path.name
        )

        # Export raw data for post-run analysis (same stem, .npz extension)
        npz_path = out_path.with_suffix(".npz")
        np.savez(npz_path, train_losses=train_losses, val_metric_values=val_metric_values)
        plt.close(fig)

plot_confusion_matrix(all_labels, all_preds, classes, out_path, ctx)

Generate and save a row-normalized confusion matrix plot.

Parameters:

Name Type Description Default
all_labels NDArray[Any]

Ground-truth label array.

required
all_preds NDArray[Any]

Predicted label array.

required
classes list[str]

Human-readable class label names.

required
out_path Path

Destination file path for the saved figure.

required
ctx PlotContext

PlotContext with architecture and evaluation settings.

required
Source code in orchard/evaluation/visualization.py
def plot_confusion_matrix(
    all_labels: npt.NDArray[Any],
    all_preds: npt.NDArray[Any],
    classes: list[str],
    out_path: Path,
    ctx: PlotContext,
) -> None:
    """
    Generate and save a row-normalized confusion matrix plot.

    Args:
        all_labels: Ground-truth label array.
        all_preds: Predicted label array.
        classes: Human-readable class label names.
        out_path: Destination file path for the saved figure.
        ctx: PlotContext with architecture and evaluation settings.
    """
    label_ids = np.arange(len(classes))
    with plt.style.context(ctx.plot_style):
        # Row-normalize so each true class sums to 1; classes absent from
        # all_labels produce NaN rows, which nan_to_num maps back to 0.
        matrix = np.nan_to_num(
            confusion_matrix(
                all_labels,
                all_preds,
                labels=label_ids,
                normalize="true",
            )
        )

        fig, ax = plt.subplots(figsize=_CM_FIGSIZE)
        display = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=classes)
        display.plot(
            ax=ax,
            cmap=ctx.cmap_confusion,
            xticks_rotation=_CM_TICKS_ROTATION,
            values_format=_CM_VALUES_FORMAT,
        )

        plt.title(
            f"Confusion Matrix — {ctx.arch_name} | Resolution — {ctx.resolution}",
            fontsize=_CM_TITLE_FONTSIZE,
            pad=_CM_TITLE_PAD,
        )
        plt.tight_layout()

        fig.savefig(out_path, dpi=ctx.fig_dpi, bbox_inches=_SAVEFIG_BBOX)
        plt.close()
        logger.info(
            "%s%s %-18s: %s", LogStyle.INDENT, LogStyle.ARROW, "Confusion Matrix", out_path.name
        )