orchard.evaluation.tta
Test-Time Augmentation (TTA) Module.
This module implements adaptive TTA strategies for robust inference. It provides an ensemble-based prediction mechanism that respects anatomical constraints and texture preservation requirements.
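Transform selection is deterministic and hardware-independent: the same
tta_mode config always produces the same ensemble, whether inference runs
on CPU, CUDA, or MPS. The choice of augmentations is therefore reproducible
across platforms, although the predicted probabilities themselves may still
differ slightly because of backend-specific floating-point behavior.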
adaptive_tta_predict(model, inputs, device, is_anatomical, is_texture_based, aug_cfg, resolution)
Performs Test-Time Augmentation (TTA) inference on a batch of inputs.
Applies a set of standard augmentations in addition to the original input.
Predictions from all augmented versions are averaged in the probability space.
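Averaging in probability space means each augmented view is passed through softmax before the views are averaged, which is not the same as averaging raw logits. The pure-Python sketch below illustrates the difference; the helper names and the two-view logits are hypothetical and chosen only for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mean_of_probs(logit_views):
    """Softmax each augmented view, then average (probability-space ensemble)."""
    probs = [softmax(view) for view in logit_views]
    n = len(probs)
    return [sum(p[c] for p in probs) / n for c in range(len(probs[0]))]


def softmax_of_mean_logits(logit_views):
    """Average raw logits first, then softmax (shown only for contrast)."""
    n = len(logit_views)
    num_classes = len(logit_views[0])
    mean_logits = [sum(v[c] for v in logit_views) / n for c in range(num_classes)]
    return softmax(mean_logits)


# Hypothetical logits for one image under two views (e.g. original + flipped).
views = [[4.0, 0.0], [0.0, 0.0]]
print(mean_of_probs(views))           # ≈ [0.741, 0.259]
print(softmax_of_mean_logits(views))  # ≈ [0.881, 0.119]
```

With probability averaging, each view contributes at most 1/n to any class score, so a single view with a very large logit cannot dominate the ensemble the way it can when logits are averaged.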
If is_anatomical is True, it restricts augmentations to orientation-preserving
transforms. If is_texture_based is True, it disables destructive pixel-level
noise/blur to preserve local patterns. The tta_mode config field controls
ensemble complexity (full vs light) independently of hardware.
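The selection rules above can be expressed as a pure function of the configuration flags. In this sketch, the transform names, the contents of the full and light ensembles, and `select_tta_transforms` itself are hypothetical: this reference does not list the transforms orchard actually uses, and the resolution-dependent intensity scaling is not modeled. Because the function never receives a device, its output cannot vary between CPU, CUDA, and MPS runs.

```python
from typing import Tuple

# Hypothetical transform names; not the actual orchard transform set.
FULL_ENSEMBLE: Tuple[str, ...] = (
    "identity",
    "hflip",
    "vflip",
    "rot90",
    "shift",
    "gaussian_noise",
    "gaussian_blur",
)
LIGHT_ENSEMBLE: Tuple[str, ...] = ("identity", "hflip", "shift")

# Transforms that change image orientation (dropped for anatomical data).
ORIENTATION_CHANGING = frozenset({"hflip", "vflip", "rot90"})
# Transforms that destroy fine pixel-level detail (dropped for texture data).
PIXEL_DESTRUCTIVE = frozenset({"gaussian_noise", "gaussian_blur"})


def select_tta_transforms(
    tta_mode: str, is_anatomical: bool, is_texture_based: bool
) -> Tuple[str, ...]:
    """Return an ordered tuple of transform names chosen from config alone."""
    if tta_mode == "full":
        candidates = FULL_ENSEMBLE
    elif tta_mode == "light":
        candidates = LIGHT_ENSEMBLE
    else:
        raise ValueError(f"unknown tta_mode: {tta_mode!r}")

    selected = []
    for name in candidates:  # fixed iteration order -> fixed ensemble order
        if is_anatomical and name in ORIENTATION_CHANGING:
            continue  # keep anatomical orientation intact
        if is_texture_based and name in PIXEL_DESTRUCTIVE:
            continue  # preserve high-frequency texture
        selected.append(name)
    return tuple(selected)


print(select_tta_transforms("full", is_anatomical=True, is_texture_based=True))
# ('identity', 'shift')
```

The identity transform is never filtered out, which matches the requirement that the original input is always part of the ensemble.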
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The trained PyTorch model. | required |
| inputs | Tensor | The batch of test images. | required |
| device | device | The device to run the inference on. | required |
| is_anatomical | bool | Whether the dataset has fixed anatomical orientation. | required |
| is_texture_based | bool | Whether the dataset relies on high-frequency textures. | required |
| aug_cfg | AugmentationConfig | Augmentation sub-configuration with TTA parameters. | required |
| resolution | int | Dataset resolution for TTA intensity scaling. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | The averaged softmax probability predictions (mean ensemble). |
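The return contract above (a mean of softmax probabilities over augmented views) can be sketched in PyTorch as follows. This is a minimal illustration, not orchard's implementation: `tta_mean_ensemble` and the flip-based view list are hypothetical, and the sketch takes an explicit list of transforms instead of deriving them from `aug_cfg`, `resolution`, and the dataset flags. It also assumes the model already lives on `device`.

```python
from typing import Callable, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


def tta_mean_ensemble(
    model: nn.Module,
    inputs: torch.Tensor,
    device: torch.device,
    transforms: Sequence[Callable[[torch.Tensor], torch.Tensor]],
) -> torch.Tensor:
    """Hypothetical helper: average softmax probabilities over augmented views.

    inputs: (N, C, H, W). Returns (N, num_classes) probabilities.
    Assumes `model` has already been moved to `device`.
    """
    model.eval()
    inputs = inputs.to(device)
    probs = []
    with torch.no_grad():
        for t in transforms:
            logits = model(t(inputs))
            probs.append(F.softmax(logits, dim=1))
    # Stack to (num_views, N, num_classes), then average over the view axis.
    return torch.stack(probs, dim=0).mean(dim=0)


# Hypothetical views: the original batch plus a horizontal flip (last axis = width).
views = [lambda x: x, lambda x: torch.flip(x, dims=[-1])]

torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
batch = torch.randn(4, 3, 8, 8)
out = tta_mean_ensemble(model, batch, torch.device("cpu"), views)
print(out.shape)  # torch.Size([4, 5])
```

Since every view yields a probability distribution, their mean is also a distribution: each row of the returned tensor sums to 1, and with only the identity view the result reduces to a plain softmax of the model's logits.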