orchard.evaluation.tta

Test-Time Augmentation (TTA) Module.

This module implements adaptive TTA strategies for robust inference. It provides an ensemble-based prediction mechanism that respects anatomical orientation constraints and avoids transforms that would destroy fine-grained textures.

Transform selection is deterministic and hardware-independent: the same configuration (including tta_mode) always produces the same ensemble, whether inference runs on CPU, CUDA, or MPS, so the set of augmentations applied is reproducible across platforms. (Floating-point results can still differ slightly between backends.)
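A minimal sketch of what device-independent selection can look like, assuming selection is a pure function of the configuration flags. The function `select_tta_transform_names`, the transform names, and the contents of the "full" and "light" ensembles are hypothetical illustrations; the module's actual factory is `_get_tta_transforms(is_anatomical, is_texture_based, aug_cfg, resolution)`, which reads its settings from `aug_cfg` rather than taking `tta_mode` as a string.

```python
from typing import Tuple

# Hypothetical transform vocabulary; the real module builds callables, not names.
_FLIPS_AND_ROTATIONS = ("hflip", "vflip", "rot90")
_ORIENTATION_SAFE = ("small_translation",)  # assumption: safe for fixed anatomy
_PIXEL_LEVEL = ("gaussian_noise", "gaussian_blur")


def select_tta_transform_names(
    tta_mode: str, is_anatomical: bool, is_texture_based: bool
) -> Tuple[str, ...]:
    """Return the ensemble as a tuple of names.

    There is no device argument, so the result cannot depend on whether
    inference later runs on CPU, CUDA, or MPS.
    """
    if tta_mode not in ("full", "light"):
        raise ValueError(f"unknown tta_mode: {tta_mode!r}")

    geometric = _ORIENTATION_SAFE if is_anatomical else _FLIPS_AND_ROTATIONS
    names = ["identity"]  # the original input is always part of the ensemble
    if tta_mode == "light":
        names.append(geometric[0])  # assumption: "light" = identity + one cheap view
    else:
        names.extend(geometric)
        if not is_texture_based:
            names.extend(_PIXEL_LEVEL)  # noise/blur only when textures don't matter
    return tuple(names)


full_natural = select_tta_transform_names("full", is_anatomical=False, is_texture_based=False)
full_anatomical_texture = select_tta_transform_names("full", is_anatomical=True, is_texture_based=True)
light_natural = select_tta_transform_names("light", is_anatomical=False, is_texture_based=False)
```

Returning an immutable tuple of names keeps the choice easy to log and compare across runs before any tensor work happens.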

adaptive_tta_predict(model, inputs, device, is_anatomical, is_texture_based, aug_cfg, resolution)

Performs Test-Time Augmentation (TTA) inference on a batch of inputs.

Runs the model on the original input and on a set of standard augmentations of it, then averages the softmax probabilities of all passes. If is_anatomical is True, only orientation-preserving transforms are used. If is_texture_based is True, destructive pixel-level noise and blur are disabled to preserve local patterns. The tta_mode config field controls ensemble complexity (full vs. light) independently of hardware.
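Written out, the mean ensemble described above is the following, where $f$ is the model, $\mathcal{T}$ is the selected transform set (assuming, as the description implies, that the original input is included as an identity transform), and the softmax is taken over the class dimension:

```latex
\hat{\mathbf{p}}(x) = \frac{1}{|\mathcal{T}|} \sum_{t \in \mathcal{T}} \operatorname{softmax}\big(f(t(x))\big)
```

Because each term is a probability vector, their average is one as well: every row of the output still sums to 1.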

Parameters:

Name              Type                Description                                             Default
model             Module              The trained PyTorch model.                              required
inputs            Tensor              The batch of test images.                               required
device            device              The device to run the inference on.                     required
is_anatomical     bool                Whether the dataset has fixed anatomical orientation.   required
is_texture_based  bool                Whether the dataset relies on high-frequency textures.  required
aug_cfg           AugmentationConfig  Augmentation sub-configuration with TTA parameters.     required
resolution        int                 Dataset resolution for TTA intensity scaling.           required

Returns:

Type    Description
Tensor  The averaged softmax probability predictions (mean ensemble).
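To make the return contract concrete, here is a dependency-free sketch of the same running-sum loop, assuming a toy two-class "model" and a batch represented as nested lists instead of tensors. `tta_mean_probs`, `identity`, `hflip`, and `toy_model` are hypothetical stand-ins for `adaptive_tta_predict`, its transforms, and the PyTorch model; only the accumulate-then-divide structure mirrors the source.

```python
import math
from typing import Callable, List, Optional, Sequence

Batch = List[List[float]]  # rows = samples; stands in for a [batch, ...] tensor


def softmax(row: Sequence[float]) -> List[float]:
    # Shift by the max before exponentiating to avoid overflow.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def tta_mean_probs(
    model: Callable[[Batch], Batch],
    inputs: Batch,
    transforms: Sequence[Callable[[Batch], Batch]],
) -> Batch:
    """Hypothetical list-based analogue of the ensemble loop."""
    if not transforms:
        raise ValueError("TTA transforms list cannot be empty")
    running: Optional[Batch] = None
    for t in transforms:
        probs = [softmax(row) for row in model(t(inputs))]
        if running is None:
            running = probs  # first pass initialises the running sum
        else:
            running = [[a + b for a, b in zip(r, p)] for r, p in zip(running, probs)]
    n = len(transforms)
    return [[v / n for v in row] for row in running]


def identity(x: Batch) -> Batch:
    return x


def hflip(x: Batch) -> Batch:
    # Reverse each row: a 1-D analogue of a horizontal image flip.
    return [list(reversed(row)) for row in x]


def toy_model(x: Batch) -> Batch:
    # Two-class "model": logits are the sums of the left and right halves.
    half = len(x[0]) // 2
    return [[sum(row[:half]), sum(row[half:])] for row in x]


batch = [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0]]
probs = tta_mean_probs(toy_model, batch, [identity, hflip])
# Each row of probs is a probability vector (sums to 1 up to float rounding).
```

Here the flip turns the first sample's logits from [3, 7] into [7, 3], so the two passes disagree symmetrically and their averaged probabilities land at about 0.5 for each class.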

Source code in orchard/evaluation/tta.py
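import torch
import torch.nn as nn
import torch.nn.functional as F

# AugmentationConfig comes from orchard's configuration package (import not shown
# in this excerpt); _get_tta_transforms is the module-level factory in this file.

def adaptive_tta_predict(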
    model: nn.Module,
    inputs: torch.Tensor,
    device: torch.device,
    is_anatomical: bool,
    is_texture_based: bool,
    aug_cfg: AugmentationConfig,
    resolution: int,
) -> torch.Tensor:
    """
    Performs Test-Time Augmentation (TTA) inference on a batch of inputs.

    Applies a set of standard augmentations in addition to the original input.
    Predictions from all augmented versions are averaged in the probability space.
    If is_anatomical is True, it restricts augmentations to orientation-preserving
    transforms. If is_texture_based is True, it disables destructive pixel-level
    noise/blur to preserve local patterns. The ``tta_mode`` config field controls
    ensemble complexity (full vs light) independently of hardware.

    Args:
        model: The trained PyTorch model.
        inputs: The batch of test images.
        device: The device to run the inference on.
        is_anatomical: Whether the dataset has fixed anatomical orientation.
        is_texture_based: Whether the dataset relies on high-frequency textures.
        aug_cfg: Augmentation sub-configuration with TTA parameters.
        resolution: Dataset resolution for TTA intensity scaling.

    Returns:
        The averaged softmax probability predictions (mean ensemble).
    """
    model.eval()
    inputs = inputs.to(device)

    # Generate the suite of transforms via module-level factory
    transforms = _get_tta_transforms(is_anatomical, is_texture_based, aug_cfg, resolution)

    # ENSEMBLE EXECUTION: Iterative probability accumulation to save VRAM
    ensemble_probs = None

    with torch.no_grad():
        for t in transforms:
            aug_input = t(inputs)
            logits = model(aug_input)
            probs = F.softmax(logits, dim=1)

            if ensemble_probs is None:
                ensemble_probs = probs
            else:
                ensemble_probs += probs

    # Calculate the mean probability across all augmentation passes
    if ensemble_probs is None:
        raise ValueError("TTA transforms list cannot be empty")
    return ensemble_probs / len(transforms)
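The loop applies `F.softmax` to each pass before accumulating, so the ensemble averages probabilities rather than logits. The two are not equivalent in general. The stdlib-only sketch below (plain lists instead of tensors; the helper names and logit values are made up for illustration) compares them for a single sample seen through two hypothetical augmented passes.

```python
import math
from typing import List, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    # Max-shifted softmax over one logit vector.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def mean(vectors: Sequence[Sequence[float]]) -> List[float]:
    # Element-wise mean of equally sized vectors.
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


# One pass is very confident in class 0; the other mildly prefers class 1.
pass_logits = [[8.0, 0.0], [0.0, 1.0]]

prob_space = mean([softmax(logits) for logits in pass_logits])  # what the loop computes
logit_space = softmax(mean(pass_logits))                        # softmax of averaged logits
```

In this example the confident pass dominates the logit-space average (roughly 0.97 for class 0) far more than the probability-space average (roughly 0.63), because averaging after the softmax bounds each pass's contribution to at most 1/|T| of the total probability mass.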