
evaluator

orchard.evaluation.evaluator

Evaluation Engine Module.

Runs batch-level inference on a labelled test set and consolidates predictions into global classification metrics (accuracy, macro F1, macro AUC). Supports optional Test-Time Augmentation via the tta sub-module, applying domain-aware transforms (anatomical, texture) and averaging softmax outputs across the ensemble.
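The averaging step of the ensemble reduces to a few lines. The sketch below is illustrative only: `tta_average` and the transform list are hypothetical stand-ins, since the actual tta sub-module selects domain-aware transforms per dataset and handles device placement internally.

```python
import torch


def tta_average(model, inputs, transforms):
    """Average softmax outputs over a set of test-time transforms.

    Illustrative sketch of the ensemble-averaging idea only; the real
    tta sub-module also chooses transforms based on dataset properties
    (anatomical, texture-based).
    """
    probs_sum = None
    with torch.no_grad():
        for tf in transforms:
            # Forward pass on the transformed view, then softmax
            probs = torch.softmax(model(tf(inputs)), dim=1)
            probs_sum = probs if probs_sum is None else probs_sum + probs
    # Mean probability across all augmented views
    return probs_sum / len(transforms)
```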

Key Functions:

  • evaluate_model: Full-dataset evaluation with optional TTA, returning predictions, labels, metric dict, and macro F1.
Example

```python
>>> preds, labels, metrics, f1 = evaluate_model(
...     model, test_loader, device, use_tta=True, aug_cfg=cfg
... )
>>> print(f"Test AUC: {metrics['auc']:.4f}")
```

evaluate_model(model, test_loader, device, use_tta=False, is_anatomical=False, is_texture_based=False, aug_cfg=None, resolution=28)

Performs full-set evaluation and coordinates metric calculation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The trained neural network. | required |
| `test_loader` | `DataLoader[Any]` | DataLoader for the evaluation set. | required |
| `device` | `device` | Hardware target (CPU/CUDA/MPS). | required |
| `use_tta` | `bool` | Flag to enable Test-Time Augmentation. | `False` |
| `is_anatomical` | `bool` | Dataset-specific orientation constraint. | `False` |
| `is_texture_based` | `bool` | Dataset-specific texture preservation flag. | `False` |
| `aug_cfg` | `AugmentationConfig \| None` | Augmentation sub-configuration (required for TTA). | `None` |
| `resolution` | `int` | Dataset resolution for TTA intensity scaling. | `28` |

Returns:

`tuple[NDArray[Any], NDArray[Any], dict[str, float], float]`: A 4-tuple of:

  • all_preds -- Predicted class indices, shape (N,)
  • all_labels -- Ground truth labels, shape (N,)
  • metrics -- dict[str, float] with keys accuracy, auc, f1
  • macro_f1 -- Macro-averaged F1 score (convenience shortcut)
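One behavioural detail worth noting: TTA only runs when `aug_cfg` is supplied. Calling with `use_tta=True` but `aug_cfg=None` does not raise; it silently falls back to a standard forward pass. The guard below mirrors the `actual_tta` predicate from the implementation (extracted here purely for illustration):

```python
# Same predicate as `actual_tta` inside evaluate_model: both the flag
# and an augmentation config are required for TTA to take effect.
def tta_active(use_tta, aug_cfg):
    return use_tta and (aug_cfg is not None)
```

So a missing `aug_cfg` never errors out; it only disables the augmentation ensemble.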
Source code in orchard/evaluation/evaluator.py
def evaluate_model(
    model: nn.Module,
    test_loader: DataLoader[Any],
    device: torch.device,
    use_tta: bool = False,
    is_anatomical: bool = False,
    is_texture_based: bool = False,
    aug_cfg: AugmentationConfig | None = None,
    resolution: int = 28,
) -> tuple[npt.NDArray[Any], npt.NDArray[Any], dict[str, float], float]:
    """
    Performs full-set evaluation and coordinates metric calculation.

    Args:
        model: The trained neural network.
        test_loader: DataLoader for the evaluation set.
        device: Hardware target (CPU/CUDA/MPS).
        use_tta: Flag to enable Test-Time Augmentation.
        is_anatomical: Dataset-specific orientation constraint.
        is_texture_based: Dataset-specific texture preservation flag.
        aug_cfg: Augmentation sub-configuration (required for TTA).
        resolution: Dataset resolution for TTA intensity scaling.

    Returns:
        tuple[np.ndarray, np.ndarray, dict, float]: A 4-tuple of:

            - **all_preds** -- Predicted class indices, shape ``(N,)``
            - **all_labels** -- Ground truth labels, shape ``(N,)``
            - **metrics** -- dict[str, float] with keys ``accuracy``, ``auc``, ``f1``
            - **macro_f1** -- Macro-averaged F1 score (convenience shortcut)
    """
    model.eval()
    all_probs_list: list[npt.NDArray[Any]] = []
    all_labels_list: list[npt.NDArray[Any]] = []

    actual_tta = use_tta and (aug_cfg is not None)

    with torch.no_grad():
        for inputs, targets in test_loader:
            if actual_tta:
                # TTA logic handles its own device placement and softmax
                probs = adaptive_tta_predict(
                    model, inputs, device, is_anatomical, is_texture_based, aug_cfg, resolution
                )
            else:
                # Standard forward pass
                inputs = inputs.to(device)
                logits = model(inputs)
                probs = torch.softmax(logits, dim=1)

            all_probs_list.append(probs.cpu().numpy())
            all_labels_list.append(targets.numpy())

    # Consolidate batch results into global arrays
    all_probs = np.concatenate(all_probs_list)
    all_labels = np.concatenate(all_labels_list)
    all_preds = all_probs.argmax(axis=1)

    # Delegate statistical analysis to the metrics module
    metrics = compute_classification_metrics(all_labels, all_preds, all_probs)

    # Performance logging
    log_msg = "%s%s %-18s: Acc: %.4f | AUC: %.4f | F1: %.4f" % (
        LogStyle.INDENT,
        LogStyle.ARROW,
        "Test Metrics",
        metrics[METRIC_ACCURACY],
        metrics[METRIC_AUC],
        metrics[METRIC_F1],
    )
    if actual_tta and aug_cfg is not None:
        mode = aug_cfg.tta_mode.upper()
        log_msg += " | TTA ENABLED (Mode: %s)" % mode

    logger.info(log_msg)

    return all_preds, all_labels, metrics, metrics[METRIC_F1]
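The statistical work is delegated to `compute_classification_metrics` in the metrics module, whose implementation is not shown on this page. Assuming standard definitions, the three reported values could be reproduced with scikit-learn along these lines; this is a sketch under that assumption, not the module's actual code:

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def classification_metrics_sketch(labels, preds, probs):
    """Hypothetical stand-in for compute_classification_metrics.

    Keys mirror those read by the logging code in evaluate_model:
    accuracy, auc (macro one-vs-rest), and f1 (macro-averaged).
    """
    return {
        "accuracy": accuracy_score(labels, preds),
        "auc": roc_auc_score(labels, probs, multi_class="ovr", average="macro"),
        "f1": f1_score(labels, preds, average="macro"),
    }
```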