trainer

orchard.trainer

Trainer Package Facade.

This package exposes the central ModelTrainer class, the optimization factories, and the low-level execution engines, providing a unified interface for the training lifecycle.

LoopOptions(grad_clip, total_epochs, mixup_epochs, use_tqdm, monitor_metric) dataclass

Scalar configuration for a TrainingLoop.

Groups training hyper-parameters that do not depend on PyTorch objects, keeping the TrainingLoop constructor lean.

Attributes:

Name Type Description
grad_clip float | None

Max norm for gradient clipping (0 or None disables).

total_epochs int

Total number of epochs (for tqdm progress bar).

mixup_epochs int

Epoch cutoff after which MixUp is disabled.

use_tqdm bool

Whether to show tqdm progress bar.

monitor_metric str

Metric key for ReduceLROnPlateau stepping (e.g. "auc", "accuracy").

TrainingLoop(model, train_loader, val_loader, optimizer, scheduler, criterion, device, scaler, mixup_fn, options)

Single-epoch execution kernel shared by ModelTrainer and TrialTrainingExecutor.

Encapsulates the per-epoch train/validate/schedule cycle. Callers own the outer epoch loop and policy decisions (checkpointing, early stopping, Optuna pruning). This class only executes one epoch at a time.

Attributes:

Name Type Description
model Module

Neural network to train.

train_loader DataLoader

Training data provider.

val_loader DataLoader

Validation data provider.

optimizer Optimizer

Gradient descent optimizer.

scheduler LRScheduler | None

Learning rate scheduler (or None).

criterion Module

Loss function.

device device

Hardware target (CUDA/MPS/CPU).

scaler GradScaler | None

AMP GradScaler (or None).

mixup_fn Callable | None

MixUp partial function (or None).

options LoopOptions

Scalar training options (see LoopOptions).

Source code in orchard/trainer/_loop.py
def __init__(
    self,
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader[Any],
    val_loader: torch.utils.data.DataLoader[Any],
    optimizer: torch.optim.Optimizer,
    scheduler: LRScheduler | None,
    criterion: nn.Module,
    device: torch.device,
    scaler: torch.amp.grad_scaler.GradScaler | None,
    mixup_fn: Callable[..., Any] | None,
    options: LoopOptions,
) -> None:
    self.model = model
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.optimizer = optimizer
    self.scheduler = scheduler
    self.criterion = criterion
    self.device = device
    self.scaler = scaler
    self.mixup_fn = mixup_fn
    self.options = options

run_train_step(epoch)

Execute a single training epoch with MixUp cutoff.

Applies MixUp augmentation only when epoch <= mixup_epochs. Does not run validation or step the scheduler.

Parameters:

Name Type Description Default
epoch int

Current epoch number (1-indexed).

required

Returns:

Type Description
float

Average training loss for the epoch.

Source code in orchard/trainer/_loop.py
def run_train_step(self, epoch: int) -> float:
    """
    Execute a single training epoch with MixUp cutoff.

    Applies MixUp augmentation only when ``epoch <= mixup_epochs``.
    Does **not** run validation or step the scheduler.

    Args:
        epoch: Current epoch number (1-indexed).

    Returns:
        Average training loss for the epoch.
    """
    current_mixup = self.mixup_fn if epoch <= self.options.mixup_epochs else None
    return train_one_epoch(
        model=self.model,
        loader=self.train_loader,
        criterion=self.criterion,
        optimizer=self.optimizer,
        device=self.device,
        mixup_fn=current_mixup,
        scaler=self.scaler,
        grad_clip=self.options.grad_clip,
        epoch=epoch,
        total_epochs=self.options.total_epochs,
        use_tqdm=self.options.use_tqdm,
    )
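The MixUp cutoff condition can be isolated as a one-liner (a sketch with a stand-in callable in place of the real mixup partial):

```python
def select_mixup(mixup_fn, epoch: int, mixup_epochs: int):
    """Return mixup_fn while epoch <= mixup_epochs, else None (MixUp off)."""
    return mixup_fn if epoch <= mixup_epochs else None


# Stand-in for a mixup partial (returns x, y_a, y_b, lam)
dummy = lambda x, y: (x, y, y, 1.0)

active = select_mixup(dummy, epoch=5, mixup_epochs=10)   # still active
disabled = select_mixup(dummy, epoch=11, mixup_epochs=10)  # past cutoff
```

Because the epoch counter is 1-indexed, `mixup_epochs=10` means MixUp runs for exactly the first ten epochs.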

run_epoch(epoch)

Execute a full train → validate → schedule cycle for one epoch.

Parameters:

Name Type Description Default
epoch int

Current epoch number (1-indexed).

required

Returns:

Type Description
tuple[float, Mapping[str, float]]

Tuple of (average training loss, validation metrics dict).

Source code in orchard/trainer/_loop.py
def run_epoch(self, epoch: int) -> tuple[float, Mapping[str, float]]:
    """
    Execute a full train → validate → schedule cycle for one epoch.

    Args:
        epoch: Current epoch number (1-indexed).

    Returns:
        Tuple of (average training loss, validation metrics dict).
    """
    train_loss = self.run_train_step(epoch)
    val_metrics = validate_epoch(
        model=self.model,
        val_loader=self.val_loader,
        criterion=self.criterion,
        device=self.device,
    )
    monitor = self.options.monitor_metric
    if monitor not in val_metrics:
        raise KeyError(
            f"Monitor metric '{monitor}' not found in validation results. "
            f"Available: {list(val_metrics.keys())}"
        )
    step_scheduler(self.scheduler, val_metrics[monitor])
    return train_loss, val_metrics
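The monitor-metric guard fails fast rather than silently stepping the scheduler on a missing key. A minimal sketch of the same check, outside the class:

```python
def check_monitor(val_metrics: dict, monitor: str) -> float:
    """Fail fast when the monitored key is absent (mirrors run_epoch's guard)."""
    if monitor not in val_metrics:
        raise KeyError(
            f"Monitor metric '{monitor}' not found in validation results. "
            f"Available: {list(val_metrics)}"
        )
    return val_metrics[monitor]


value = check_monitor({"loss": 0.4, "auc": 0.91}, "auc")
```

A misspelled `monitor_metric` in the config therefore surfaces on the first epoch, not as a silently frozen learning rate.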

ModelTrainer(model, train_loader, val_loader, optimizer, scheduler, criterion, device, training, output_path=None, tracker=None)

Encapsulates the core training, validation, and scheduling logic.

Manages the complete training lifecycle including epoch iteration, metric tracking, automated checkpointing based on validation performance, and early stopping with patience-based criteria. Integrates modern training techniques (AMP, Mixup, gradient clipping) and ensures deterministic model restoration to best-performing weights.

The trainer follows a structured execution flow:

  1. Training Phase: Forward/backward passes with optional Mixup augmentation
  2. Validation Phase: Performance evaluation on held-out data
  3. Scheduling Phase: Learning rate updates (ReduceLROnPlateau or step-based)
  4. Checkpointing: Save model when monitor_metric improves
  5. Early Stopping: Halt training if no improvement for patience epochs

Attributes:

Name Type Description
model

Neural network architecture to train.

train_loader

Training data provider.

val_loader

Validation data provider.

optimizer

Gradient descent optimizer.

scheduler

Learning rate scheduler.

criterion

Loss function (e.g., CrossEntropyLoss).

device

Hardware target (CUDA/MPS/CPU).

training

Training hyperparameters sub-config.

epochs

Total number of training epochs.

patience

Early stopping patience (epochs without improvement).

best_acc

Best validation accuracy achieved.

best_metric

Best value of the monitored metric.

epochs_no_improve

Consecutive epochs without monitored metric improvement.

scaler

AMP scaler (None when use_amp is False).

mixup_fn

Mixup augmentation function (partial of mixup_data).

best_path

Filesystem path for best model checkpoint.

train_losses list[float]

Training loss history per epoch.

val_metrics_history list[Mapping[str, float]]

Validation metrics history per epoch.

monitor_metric

Name of metric driving checkpointing.

_loop

Shared epoch kernel handling train → validate → schedule.

Example

>>> from orchard.trainer import ModelTrainer
>>> trainer = ModelTrainer(
...     model=model,
...     train_loader=train_loader,
...     val_loader=val_loader,
...     optimizer=optimizer,
...     scheduler=scheduler,
...     criterion=criterion,
...     device=device,
...     training=cfg.training,
...     output_path=paths.checkpoints / "best_model.pth"
... )
>>> checkpoint_path, losses, metrics = trainer.train()
>>> # Model automatically restored to best weights

Initializes the ModelTrainer with all required training components.

Parameters:

Name Type Description Default
model Module

Neural network architecture to train.

required
train_loader DataLoader[Any]

DataLoader for training dataset.

required
val_loader DataLoader[Any]

DataLoader for validation dataset.

required
optimizer Optimizer

Gradient descent optimizer (e.g., SGD, AdamW).

required
scheduler LRScheduler

Learning rate scheduler for training dynamics.

required
criterion Module

Loss function for optimisation (e.g., CrossEntropyLoss).

required
device device

Compute device for training.

required
training TrainingConfig

Training hyperparameters sub-config.

required
output_path Path | None

Path for best model checkpoint (default: ./best_model.pth).

None
tracker TrackerProtocol | None

Optional experiment tracker for MLflow metric logging.

None
Source code in orchard/trainer/trainer.py
def __init__(
    self,
    model: nn.Module,
    train_loader: DataLoader[Any],
    val_loader: DataLoader[Any],
    optimizer: torch.optim.Optimizer,
    scheduler: LRScheduler,
    criterion: nn.Module,
    device: torch.device,
    training: TrainingConfig,
    output_path: Path | None = None,
    tracker: TrackerProtocol | None = None,
) -> None:
    """
    Initializes the ModelTrainer with all required training components.

    Args:
        model: Neural network architecture to train.
        train_loader: DataLoader for training dataset.
        val_loader: DataLoader for validation dataset.
        optimizer: Gradient descent optimizer (e.g., SGD, AdamW).
        scheduler: Learning rate scheduler for training dynamics.
        criterion: Loss function for optimisation (e.g., CrossEntropyLoss).
        device: Compute device for training.
        training: Training hyperparameters sub-config.
        output_path: Path for best model checkpoint (default: ``./best_model.pth``).
        tracker: Optional experiment tracker for MLflow metric logging.
    """
    self.model = model
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.optimizer = optimizer
    self.scheduler = scheduler
    self.criterion = criterion
    self.device = device
    self.training = training
    self.tracker = tracker

    # Hyperparameters
    self.epochs = training.epochs
    self.patience = training.patience
    self.monitor_metric = training.monitor_metric
    self.best_acc = -1.0  # Logging-only: always shown in summary regardless of monitor_metric
    self.best_metric = -float("inf")
    self.epochs_no_improve = 0

    # AMP and MixUp (shared factories from _loop)
    self.scaler = create_amp_scaler(training, device=str(device))
    self.mixup_fn = create_mixup_fn(training)

    # Output Management
    self.best_path = output_path or Path("./best_model.pth")
    self.best_path.parent.mkdir(parents=True, exist_ok=True)

    # History tracking
    self.train_losses: list[float] = []
    self.val_metrics_history: list[Mapping[str, float]] = []

    # Track if we saved at least one valid checkpoint during training
    self._checkpoint_saved: bool = False

    # Shared epoch kernel (train → validate → schedule)
    self._loop = TrainingLoop(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        device=device,
        scaler=self.scaler,
        mixup_fn=self.mixup_fn,
        options=LoopOptions(
            grad_clip=training.grad_clip,
            total_epochs=self.epochs,
            mixup_epochs=training.mixup_epochs,
            use_tqdm=training.use_tqdm,
            monitor_metric=self.monitor_metric,
        ),
    )

    logger.info(
        "%s%s %-18s: %s",
        LogStyle.INDENT,
        LogStyle.ARROW,
        "Checkpoint",
        self.best_path.name,
    )
    logger.info("")

train()

Executes the main training loop with checkpointing and early stopping.

Performs iterative training across configured epochs, executing:

  • Forward/backward passes with optional Mixup augmentation
  • Validation metric computation (loss, accuracy, AUC)
  • Learning rate scheduling (plateau-aware or step-based)
  • Automated checkpointing on monitor_metric improvement
  • Early stopping with patience-based criteria

Returns:

Type Description
tuple[Path, list[float], list[Mapping[str, float]]]

tuple containing:

  • Path: Filesystem path to best model checkpoint
  • list[float]: Training loss history per epoch
  • list[dict]: Validation metrics history (loss, accuracy, AUC per epoch)

Notes:

  • Model weights are automatically restored to best checkpoint after training
  • Mixup augmentation is disabled after mixup_epochs
  • Early stopping triggers if no monitor_metric improvement for patience epochs
Source code in orchard/trainer/trainer.py
def train(self) -> tuple[Path, list[float], list[Mapping[str, float]]]:
    """
    Executes the main training loop with checkpointing and early stopping.

    Performs iterative training across configured epochs, executing:

    - Forward/backward passes with optional Mixup augmentation
    - Validation metric computation (loss, accuracy, AUC)
    - Learning rate scheduling (plateau-aware or step-based)
    - Automated checkpointing on monitor_metric improvement
    - Early stopping with patience-based criteria

    Returns:
        tuple containing:

    - Path: Filesystem path to best model checkpoint
    - list[float]: Training loss history per epoch
    - list[dict]: Validation metrics history (loss, accuracy, AUC per epoch)

    Notes:

    - Model weights are automatically restored to best checkpoint after training
    - Mixup augmentation is disabled after mixup_epochs
    - Early stopping triggers if no monitor_metric improvement for `patience` epochs
    """
    for epoch in range(1, self.epochs + 1):
        header = " Epoch %02d/%d " % (epoch, self.epochs)  # pragma: no mutate
        logger.info(header.center(LogStyle.HEADER_WIDTH, "-"))

        # --- 1. Train → Validate → Schedule (delegated to _loop) ---
        epoch_loss, val_metrics = self._loop.run_epoch(epoch)
        self.train_losses.append(epoch_loss)
        self.val_metrics_history.append(val_metrics)

        val_acc = val_metrics[METRIC_ACCURACY]
        val_loss = val_metrics[METRIC_LOSS]  # Informational: logged but not used for decisions
        monitor_value = val_metrics[self.monitor_metric]

        if val_acc > self.best_acc:
            self.best_acc = val_acc

        # --- 2. Checkpointing ---
        early_stop = self._handle_checkpointing(val_metrics)

        # --- 3. Epoch Summary ---
        current_lr = self.optimizer.param_groups[0]["lr"]
        self._log_epoch_summary(
            epoch,
            epoch_loss,
            val_loss,
            val_acc,
            monitor_value,
            current_lr,
        )

        # --- 4. Experiment Tracking ---
        if self.tracker is not None:
            self.tracker.log_epoch(epoch, epoch_loss, val_metrics, current_lr)

        # --- 5. Early Stopping ---
        if early_stop:
            logger.warning("Early stopping triggered at epoch %d.", epoch)
            break

    self._log_training_complete()
    self._finalize_weights()

    return self.best_path, self.train_losses, self.val_metrics_history
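The patience-based early-stopping policy described above can be sketched as a small pure-Python state update (illustrative only; the trainer's own bookkeeping lives in its private helpers, whose names are not shown here):

```python
def update_early_stop(best: float, current: float, no_improve: int, patience: int):
    """One step of patience-based early stopping for a maximized metric."""
    if current > best:
        return current, 0, False          # new best: reset the counter
    no_improve += 1
    return best, no_improve, no_improve >= patience


best, counter, stop = 0.80, 0, False
for metric in [0.82, 0.81, 0.79, 0.80]:   # improves once, then stalls
    best, counter, stop = update_early_stop(best, metric, counter, patience=3)
```

After the fourth stalled-or-worse epoch the counter reaches the patience of 3 and the loop would break, while `best` still records the 0.82 checkpoint.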

load_best_weights()

Load the best checkpoint from disk into the model (device-aware).

Raises:

Type Description
RuntimeError

If the state-dict is incompatible with the model.

OrchardExportError

If the checkpoint file does not exist.

Source code in orchard/trainer/trainer.py
def load_best_weights(self) -> None:
    """
    Load the best checkpoint from disk into the model (device-aware).

    Raises:
        RuntimeError: If the state-dict is incompatible with the model.
        OrchardExportError: If the checkpoint file does not exist.
    """
    try:
        load_model_weights(model=self.model, path=self.best_path, device=self.device)
        logger.info("%s%s Model state restored", LogStyle.INDENT, LogStyle.SUCCESS)
        logger.info(
            "%s%s %-18s: %s",
            LogStyle.INDENT,
            LogStyle.ARROW,
            "Checkpoint",
            self.best_path.name,
        )
    except (RuntimeError, OrchardExportError) as e:
        logger.error("%s%s Weight restoration failed: %s", LogStyle.INDENT, LogStyle.FAILURE, e)
        raise

create_amp_scaler(training, device='cuda')

Create AMP GradScaler if mixed precision is enabled.

Parameters:

Name Type Description Default
training TrainingConfig

Training sub-config (reads use_amp).

required
device str

Target device string ("cuda" or "mps").

'cuda'

Returns:

Type Description
GradScaler | None

GradScaler instance when AMP is enabled, None otherwise.

Source code in orchard/trainer/_loop.py
def create_amp_scaler(
    training: TrainingConfig, device: str = "cuda"  # pragma: no mutate
) -> torch.amp.grad_scaler.GradScaler | None:
    """
    Create AMP GradScaler if mixed precision is enabled.

    Args:
        training: Training sub-config (reads ``use_amp``).
        device: Target device string (``"cuda"`` or ``"mps"``).

    Returns:
        GradScaler instance when AMP is enabled, None otherwise.
    """
    return torch.amp.grad_scaler.GradScaler(device=device) if training.use_amp else None

create_mixup_fn(training)

Create a seeded MixUp partial function if alpha > 0.

Parameters:

Name Type Description Default
training TrainingConfig

Training sub-config (reads mixup_alpha and seed).

required

Returns:

Type Description
Callable[..., Any] | None

Partial of mixup_data with fixed alpha and seeded RNG, or None when MixUp is disabled.

Source code in orchard/trainer/_loop.py
def create_mixup_fn(training: TrainingConfig) -> Callable[..., Any] | None:
    """
    Create a seeded MixUp partial function if alpha > 0.

    Args:
        training: Training sub-config (reads ``mixup_alpha``
            and ``seed``).

    Returns:
        Partial of ``mixup_data`` with fixed alpha and seeded RNG,
        or None when MixUp is disabled.
    """
    if training.mixup_alpha > 0:
        rng = np.random.default_rng(training.seed)
        return partial(mixup_data, alpha=training.mixup_alpha, rng=rng)
    return None

compute_auc(y_true, y_score)

Compute macro-averaged ROC-AUC with graceful fallback.

Handles binary (positive class probability) and multiclass (OvR) cases. Returns NaN on failure so callers can distinguish "computation impossible" from "genuinely zero AUC".

Parameters:

Name Type Description Default
y_true NDArray[Any]

Ground truth class indices, shape (N,).

required
y_score NDArray[Any]

Probability distributions, shape (N, C) (softmax output).

required

Returns:

Type Description
float

ROC-AUC score, or NaN if computation fails.

Source code in orchard/trainer/engine.py
def compute_auc(y_true: npt.NDArray[Any], y_score: npt.NDArray[Any]) -> float:
    """
    Compute macro-averaged ROC-AUC with graceful fallback.

    Handles binary (positive class probability) and multiclass (OvR)
    cases. Returns ``NaN`` on failure so callers can distinguish
    "computation impossible" from "genuinely zero AUC".

    Args:
        y_true: Ground truth class indices, shape ``(N,)``.
        y_score: Probability distributions, shape ``(N, C)`` (softmax output).

    Returns:
        ROC-AUC score, or ``NaN`` if computation fails.
    """
    try:
        n_classes = y_score.shape[1] if y_score.ndim == 2 else 1
        if n_classes <= 2:
            auc = roc_auc_score(y_true, y_score[:, 1] if y_score.ndim == 2 else y_score)
        else:
            auc = roc_auc_score(y_true, y_score, multi_class="ovr", average="macro")
    except (ValueError, TypeError, IndexError) as e:
        logger.warning("ROC-AUC calculation failed: %s. Returning NaN.", e)
        return float("nan")

    if np.isnan(auc):
        return float("nan")
    return float(auc)
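The binary/multiclass dispatch hinges entirely on the shape of `y_score`. A shape-only sketch of which branch the function takes (no sklearn call, just the routing logic):

```python
import numpy as np


def auc_branch(y_score: np.ndarray) -> str:
    """Report which roc_auc_score call path compute_auc would take."""
    n_classes = y_score.shape[1] if y_score.ndim == 2 else 1
    # <= 2 classes: positive-class column; otherwise one-vs-rest macro
    return "binary" if n_classes <= 2 else "multiclass_ovr"


branch_2c = auc_branch(np.zeros((10, 2)))  # softmax over 2 classes
branch_5c = auc_branch(np.zeros((10, 5)))  # softmax over 5 classes
```

For the two-class case only column 1 (the positive-class probability) is passed on, which is what `roc_auc_score` expects for binary targets.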

mixup_data(x, y, alpha=1.0, rng=None)

Applies MixUp augmentation by blending two random samples.

MixUp generates convex combinations of training pairs to improve generalization and calibration.

Parameters:

Name Type Description Default
x Tensor

Input data batch (images)

required
y Tensor

Target labels batch

required
alpha float

Beta distribution parameter (0 disables MixUp)

1.0
rng Generator | None

NumPy random generator for reproducibility (seeded from config)

None

Returns:

Type Description
tuple[Tensor, Tensor, Tensor, float]

4-tuple of (mixed_x, y_a, y_b, lam).

Source code in orchard/trainer/engine.py
def mixup_data(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 1.0,
    rng: np.random.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """
    Applies MixUp augmentation by blending two random samples.

    MixUp generates convex combinations of training pairs to improve
    generalization and calibration.

    Args:
        x: Input data batch (images)
        y: Target labels batch
        alpha: Beta distribution parameter (0 disables MixUp)
        rng: NumPy random generator for reproducibility (seeded from config)

    Returns:
        4-tuple of (mixed_x, y_a, y_b, lam).
    """
    if alpha <= 0:
        return x, y, y, 1.0

    if rng is None:
        # Defensive fallback — production path always provides a seeded rng
        # via ModelTrainer (seeded from cfg.training.seed).
        rng = np.random.default_rng(seed=42)

    # Draw mixing coefficient from Beta distribution
    lam: float = float(rng.beta(alpha, alpha))
    batch_size: int = x.size(0)

    # Generate random permutation with a seeded generator for reproducibility
    g = torch.Generator(device=x.device)
    g.manual_seed(int(rng.integers(2**31)))
    index = torch.randperm(batch_size, device=x.device, generator=g)

    # Create mixed input
    mixed_x: torch.Tensor = lam * x + (1 - lam) * x[index]

    # Get corresponding targets
    y_a: torch.Tensor = y
    y_b: torch.Tensor = y[index]

    return mixed_x, y_a, y_b, lam
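The blend itself is a convex combination, so every mixed sample lies between its two source samples coordinate-wise. A NumPy-only demo of the same formula on a toy batch:

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = 0.4
lam = float(rng.beta(alpha, alpha))        # mixing coefficient in (0, 1)

x = np.arange(8.0).reshape(4, 2)           # toy "batch" of 4 samples
perm = rng.permutation(len(x))             # partner sample for each row
mixed = lam * x + (1 - lam) * x[perm]      # convex combination, as above

# Each mixed row lies on the segment between its two source rows
lo = np.minimum(x, x[perm])
hi = np.maximum(x, x[perm])
in_hull = np.all((mixed >= lo - 1e-9) & (mixed <= hi + 1e-9))
```

With a small `alpha` the Beta(alpha, alpha) draw concentrates near 0 or 1, so most batches are only lightly mixed; `alpha=1.0` gives a uniform `lam`.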

train_one_epoch(model, loader, criterion, optimizer, device, mixup_fn=None, scaler=None, grad_clip=0.0, epoch=0, total_epochs=1, use_tqdm=True)

Performs a single full pass over the training dataset.

Parameters:

Name Type Description Default
model Module

Neural network architecture to train

required
loader DataLoader[Any]

Training data provider

required
criterion Module

Loss function

required
optimizer Optimizer

Gradient descent optimizer

required
device device

Hardware target (CUDA/MPS/CPU)

required
mixup_fn Callable[..., Any] | None

Function to apply MixUp data blending (optional)

None
scaler GradScaler | None

PyTorch GradScaler for mixed precision training (optional)

None
grad_clip float | None

Max norm for gradient clipping (0 disables)

0.0
epoch int

Current epoch index for progress bar

0
total_epochs int

Total number of epochs (for progress bar)

1
use_tqdm bool

Show progress bar during training

True

Returns:

Type Description
float

Average training loss for the epoch

Source code in orchard/trainer/engine.py
def train_one_epoch(
    model: nn.Module,
    loader: torch.utils.data.DataLoader[Any],
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    mixup_fn: Callable[..., Any] | None = None,
    scaler: torch.amp.grad_scaler.GradScaler | None = None,
    grad_clip: float | None = 0.0,
    epoch: int = 0,
    total_epochs: int = 1,
    use_tqdm: bool = True,
) -> float:
    """
    Performs a single full pass over the training dataset.

    Args:
        model: Neural network architecture to train
        loader: Training data provider
        criterion: Loss function
        optimizer: Gradient descent optimizer
        device: Hardware target (CUDA/MPS/CPU)
        mixup_fn: Function to apply MixUp data blending (optional)
        scaler: PyTorch GradScaler for mixed precision training (optional)
        grad_clip: Max norm for gradient clipping (0 disables)
        epoch: Current epoch index for progress bar
        total_epochs: Total number of epochs (for progress bar)
        use_tqdm: Show progress bar during training

    Returns:
        Average training loss for the epoch
    """
    model.train()
    running_loss = 0.0
    total_samples = 0

    # Create iterator with or without progress bar
    if use_tqdm:
        iterator = tqdm(loader, desc=f"Train Epoch {epoch}/{total_epochs}", leave=True, ncols=100)
    else:
        iterator = loader

    # Resolve autocast device type for AMP
    amp_enabled = scaler is not None
    amp_device_type = device.type if amp_enabled else "cpu"

    # Training loop - iterate directly without enumerate
    for inputs, targets in iterator:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # Forward pass with optional AMP autocast
        with torch.autocast(device_type=amp_device_type, enabled=amp_enabled):
            # Apply MixUp if enabled
            if mixup_fn:
                inputs, y_a, y_b, lam = mixup_fn(inputs, targets)
                outputs = model(inputs)
                loss = lam * criterion(outputs, y_a) + (1 - lam) * criterion(outputs, y_b)
            else:
                outputs = model(inputs)
                loss = criterion(outputs, targets)

        # Guard: halt on diverged loss to prevent saving corrupted weights
        if torch.isnan(loss) or torch.isinf(loss):
            raise RuntimeError(
                f"Training diverged: loss={loss.item()} at epoch {epoch}. "
                "Check learning rate, data preprocessing, or enable gradient clipping."
            )

        # Backward pass with optional AMP and gradient clipping
        _backward_step(loss, optimizer, model, scaler, grad_clip)

        # Accumulate loss (extract scalar once to avoid repeated GPU→CPU sync)
        loss_val = loss.item()
        batch_size = inputs.size(0)
        running_loss += loss_val * batch_size
        total_samples += batch_size

        # Update progress bar with current loss
        if use_tqdm:
            iterator.set_postfix({"loss": f"{loss_val:.4f}"})

    # Handle empty training set (defensive guard)
    if total_samples == 0:
        logger.warning("Empty training set: no samples processed. Returning zero loss.")
        return 0.0

    return running_loss / total_samples
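The `running_loss / total_samples` accumulation above is a sample-weighted average, which stays correct when the final batch is smaller than the rest. A minimal sketch:

```python
def weighted_epoch_loss(batch_losses, batch_sizes):
    """Average per-batch losses weighted by batch size (mirrors the
    running_loss / total_samples accumulation in train_one_epoch)."""
    total = sum(batch_sizes)
    if total == 0:  # defensive guard for an empty loader
        return 0.0
    return sum(l * n for l, n in zip(batch_losses, batch_sizes)) / total


# Two batches: a full batch of 3 and a ragged final batch of 1
loss = weighted_epoch_loss([1.0, 0.5], [3, 1])
```

A naive mean of the per-batch losses would give 0.75 here; weighting by batch size gives the true per-sample average of 0.875.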

validate_epoch(model, val_loader, criterion, device)

Evaluates model performance on held-out validation set.

Computes validation loss, accuracy, and ROC-AUC score under no_grad context. AUC calculated using One-vs-Rest (OvR) strategy with macro-averaging for robust performance estimation on potentially imbalanced datasets.

Parameters:

Name Type Description Default
model Module

Neural network model to evaluate

required
val_loader DataLoader[Any]

Validation data provider

required
criterion Module

Loss function (e.g., CrossEntropyLoss)

required
device device

Hardware target (CUDA/MPS/CPU)

required

Returns:

Type Description
Mapping[str, float]

Validation metrics dict with keys:

  • loss: Average cross-entropy loss
  • accuracy: Classification accuracy [0.0, 1.0]
  • auc: Macro-averaged Area Under the ROC Curve
  • f1: Macro-averaged F1 score
Source code in orchard/trainer/engine.py
def validate_epoch(
    model: nn.Module,
    val_loader: torch.utils.data.DataLoader[Any],
    criterion: nn.Module,
    device: torch.device,
) -> Mapping[str, float]:
    """
    Evaluates model performance on held-out validation set.

    Computes validation loss, accuracy, and ROC-AUC score under no_grad context.
    AUC calculated using One-vs-Rest (OvR) strategy with macro-averaging for
    robust performance estimation on potentially imbalanced datasets.

    Args:
        model: Neural network model to evaluate
        val_loader: Validation data provider
        criterion: Loss function (e.g., CrossEntropyLoss)
        device: Hardware target (CUDA/MPS/CPU)

    Returns:
        Validation metrics dict with keys:

        - ``loss``: Average cross-entropy loss
        - ``accuracy``: Classification accuracy [0.0, 1.0]
        - ``auc``: Macro-averaged Area Under the ROC Curve
        - ``f1``: Macro-averaged F1 score
    """
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    # Buffers for global metrics (CPU to save VRAM)
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)

            # Collect probabilities for AUC (move to CPU to save VRAM)
            probs = torch.softmax(outputs, dim=1)
            all_targets.append(targets.cpu())
            all_probs.append(probs.cpu())

            # Loss computation
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)

            # Accuracy computation
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    # Handle empty validation set (defensive guard)
    if total == 0 or len(all_targets) == 0:
        logger.warning("Empty validation set: no samples processed. Returning zero metrics.")
        return MappingProxyType(
            {METRIC_LOSS: 0.0, METRIC_ACCURACY: 0.0, METRIC_AUC: 0.0, METRIC_F1: 0.0}
        )

    # Global metric computation
    y_true = torch.cat(all_targets).numpy()
    y_score = torch.cat(all_probs).numpy()
    y_pred = y_score.argmax(axis=1)

    auc = compute_auc(y_true, y_score)
    macro_f1 = float(f1_score(y_true, y_pred, average="macro", zero_division=0.0))

    return MappingProxyType(
        {
            METRIC_LOSS: val_loss / total,
            METRIC_ACCURACY: correct / total,
            METRIC_AUC: auc,
            METRIC_F1: macro_f1,
        }
    )

compute_class_weights(labels, num_classes, device)

Compute balanced class weights (sklearn formula: N / (n_classes * count_c)).

Parameters:

Name Type Description Default
labels NDArray[Any]

Training set labels (1D array).

required
num_classes int

Total number of classes.

required
device device

Target device for the weight tensor.

required

Returns:

Type Description
Tensor

1D tensor of per-class weights, shape (num_classes,).

Source code in orchard/trainer/setup.py
def compute_class_weights(
    labels: npt.NDArray[Any], num_classes: int, device: torch.device
) -> torch.Tensor:
    """
    Compute balanced class weights (sklearn formula: N / (n_classes * count_c)).

    Args:
        labels: Training set labels (1D array).
        num_classes: Total number of classes.
        device: Target device for the weight tensor.

    Returns:
        1D tensor of per-class weights, shape ``(num_classes,)``.
    """
    classes, counts = np.unique(labels, return_counts=True)
    n_total = len(labels)
    weight_map = {int(c): n_total / (num_classes * cnt) for c, cnt in zip(classes, counts)}
    weights = [weight_map.get(i, 1.0) for i in range(num_classes)]
    return torch.tensor(weights, dtype=torch.float).to(device)
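The balanced-weight formula N / (n_classes * count_c) can be checked by hand on a toy label array (NumPy only, no torch tensor at the end):

```python
import numpy as np

labels = np.array([0, 0, 0, 1])            # 3:1 imbalance, 2 classes
classes, counts = np.unique(labels, return_counts=True)
n_total, num_classes = len(labels), 2

# sklearn "balanced" formula: N / (n_classes * count_c)
weights = {int(c): n_total / (num_classes * cnt) for c, cnt in zip(classes, counts)}
```

The majority class (count 3) gets weight 4 / (2 * 3) = 2/3 and the minority class (count 1) gets 4 / (2 * 1) = 2.0, so rare classes contribute proportionally more to the loss.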

get_criterion(training, class_weights=None)

Universal Vision Criterion Factory.

Parameters:

Name Type Description Default
training TrainingConfig

Training sub-config with criterion parameters.

required
class_weights Tensor | None

Optional per-class weights for imbalanced datasets.

None

Returns:

Type Description
Module

Loss module (CrossEntropyLoss or FocalLoss).

Raises:

Type Description
OrchardConfigError

If training.criterion_type is not recognised.

Source code in orchard/trainer/setup.py
def get_criterion(training: TrainingConfig, class_weights: torch.Tensor | None = None) -> nn.Module:
    """
    Universal Vision Criterion Factory.

    Args:
        training: Training sub-config with criterion parameters.
        class_weights: Optional per-class weights for imbalanced datasets.

    Returns:
        Loss module (CrossEntropyLoss or FocalLoss).

    Raises:
        OrchardConfigError: If ``training.criterion_type`` is not recognised.
    """
    c_type = training.criterion_type.lower()
    weights = class_weights if training.weighted_loss else None

    if c_type == "cross_entropy":
        return nn.CrossEntropyLoss(label_smoothing=training.label_smoothing, weight=weights)

    elif c_type == "focal":
        return FocalLoss(gamma=training.focal_gamma, weight=weights)

    else:
        raise OrchardConfigError(f"Unknown criterion type: {c_type}")
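For intuition on the **focal** branch: focal loss (Lin et al.) scales per-sample cross-entropy by `(1 - p)^gamma`, so well-classified examples contribute almost nothing to the gradient. A standalone numeric sketch of that modulating factor — this is the standard formula, not the internals of the package's `FocalLoss` class:

```python
import math

def focal_term(p_true: float, gamma: float = 2.0) -> float:
    """Per-sample focal loss on the true-class probability: -(1 - p)^gamma * log(p).

    With gamma=0 this reduces to plain cross-entropy, -log(p).
    """
    return -((1.0 - p_true) ** gamma) * math.log(p_true)

# An easy example (p=0.9) keeps only (0.1)^2 = 1% of its cross-entropy loss;
# a hard example (p=0.1) keeps (0.9)^2 = 81% of it.
easy_ce, easy_fl = -math.log(0.9), focal_term(0.9)
hard_ce, hard_fl = -math.log(0.1), focal_term(0.1)
```

This is why `focal_gamma` is typically tuned upward for heavily imbalanced datasets: larger `gamma` pushes training effort toward the hard minority examples.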

get_optimizer(model, training)

Factory function to instantiate optimizer from config.

Dispatches on `training.optimizer_type`:

- **sgd** — SGD with momentum, suited for convolutional architectures.
- **adamw** — AdamW with decoupled weight decay, suited for transformers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | Network whose parameters will be optimised. | *required* |
| `training` | `TrainingConfig` | Training sub-config with optimizer hyper-parameters. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Optimizer` | Configured optimizer instance. |

Raises:

| Type | Description |
| --- | --- |
| `OrchardConfigError` | If `training.optimizer_type` is not recognised. |
Source code in orchard/trainer/setup.py
def get_optimizer(model: nn.Module, training: TrainingConfig) -> optim.Optimizer:
    """
    Factory function to instantiate optimizer from config.

    Dispatches on ``training.optimizer_type``:

    - **sgd** — SGD with momentum, suited for convolutional architectures.
    - **adamw** — AdamW with decoupled weight decay, suited for transformers.

    Args:
        model: Network whose parameters will be optimised.
        training: Training sub-config with optimizer hyper-parameters.

    Returns:
        Configured optimizer instance.

    Raises:
        OrchardConfigError: If ``training.optimizer_type`` is not recognised.
    """
    opt_type = training.optimizer_type.lower()

    if opt_type == "sgd":
        return optim.SGD(
            model.parameters(),
            lr=training.learning_rate,
            momentum=training.momentum,
            weight_decay=training.weight_decay,
        )

    elif opt_type == "adamw":
        return optim.AdamW(
            model.parameters(),
            lr=training.learning_rate,
            weight_decay=training.weight_decay,
        )

    else:
        raise OrchardConfigError(
            f"Unknown optimizer type: '{opt_type}'. Available options: ['sgd', 'adamw']"
        )

get_scheduler(optimizer, training)

Advanced Scheduler Factory.

Supports multiple LR decay strategies based on TrainingConfig:

- **cosine** — Smooth decay following a cosine curve.
- **plateau** — Reduces LR when `monitor_metric` stops improving (`mode="max"`).
- **step** — Periodic reduction by a fixed factor.
- **none** — Maintains a constant learning rate.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `optimizer` | `Optimizer` | Optimizer whose learning rate will be scheduled. | *required* |
| `training` | `TrainingConfig` | Training sub-config with scheduler hyper-parameters. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `CosineAnnealingLR \| ReduceLROnPlateau \| StepLR \| LambdaLR` | Configured learning rate scheduler instance. |

Raises:

| Type | Description |
| --- | --- |
| `OrchardConfigError` | If `training.scheduler_type` is not recognised. |
Source code in orchard/trainer/setup.py
def get_scheduler(
    optimizer: optim.Optimizer, training: TrainingConfig
) -> (
    lr_scheduler.CosineAnnealingLR
    | lr_scheduler.ReduceLROnPlateau
    | lr_scheduler.StepLR
    | lr_scheduler.LambdaLR
):
    """
    Advanced Scheduler Factory.

    Supports multiple LR decay strategies based on TrainingConfig:

    - **cosine** — Smooth decay following a cosine curve.
    - **plateau** — Reduces LR when ``monitor_metric`` stops improving (``mode="max"``).
    - **step** — Periodic reduction by a fixed factor.
    - **none** — Maintains a constant learning rate.

    Args:
        optimizer: Optimizer whose learning rate will be scheduled.
        training: Training sub-config with scheduler hyper-parameters.

    Returns:
        Configured learning rate scheduler instance.

    Raises:
        OrchardConfigError: If ``training.scheduler_type`` is not recognised.
    """
    sched_type = training.scheduler_type.lower()

    if sched_type == "cosine":
        return lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=training.epochs, eta_min=training.min_lr
        )

    elif sched_type == "plateau":
        # monitor_metric is Literal["auc", "accuracy", "f1"] — all maximize
        return lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="max",
            factor=training.scheduler_factor,
            patience=training.scheduler_patience,
            min_lr=training.min_lr,
        )

    elif sched_type == "step":
        return lr_scheduler.StepLR(
            optimizer, step_size=training.step_size, gamma=training.scheduler_factor
        )

    elif sched_type == "none":
        # Returns a dummy scheduler that keeps LR constant
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda _epoch: 1.0)

    else:
        raise OrchardConfigError(
            f"Unsupported scheduler_type: '{sched_type}'. "
            "Available options: ['cosine', 'plateau', 'step', 'none']"
        )