objective

orchard.optimization.objective

Optuna objective components for the training pipeline.

This package provides the Optuna objective function and its supporting components, structured around single-responsibility modules for configuration building, metric extraction, and training execution.

TrialConfigBuilder(base_cfg)

Builds trial-specific Config instances for Optuna trials.

Handles parameter mapping from Optuna's flat namespace to Config's hierarchical structure, preserves dataset metadata excluded from serialization, and validates via Pydantic.
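The flat-to-hierarchical mapping can be sketched as follows. This is an illustrative assumption only: the real mapping lives in `TrialConfigBuilder._apply_param_overrides`, which is not shown on this page, and the flat names and target paths below are invented for the sketch.

```python
# Hypothetical sketch of flat Optuna names mapping into a nested
# config dict; the actual mapping table is an assumption.
FLAT_TO_NESTED = {
    "learning_rate": ("training", "learning_rate"),
    "dropout": ("architecture", "dropout"),
}

def apply_overrides(config_dict: dict, trial_params: dict) -> dict:
    """Copy flat Optuna samples into the nested config dict."""
    for name, value in trial_params.items():
        section, key = FLAT_TO_NESTED[name]
        config_dict[section][key] = value
    return config_dict

cfg = {"training": {"learning_rate": 0.01}, "architecture": {"dropout": 0.0}}
apply_overrides(cfg, {"learning_rate": 0.001, "dropout": 0.3})
```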

Attributes:

  • base_cfg: Base configuration template
  • optuna_epochs: Number of epochs for Optuna trials (from cfg.optuna.epochs)
  • base_metadata: Cached dataset metadata

Example

builder = TrialConfigBuilder(base_cfg)
trial_params = {"learning_rate": 0.001, "dropout": 0.3}
trial_cfg = builder.build(trial_params)

Initialize config builder.

Parameters:

  • base_cfg (Config, required): Base configuration template

Source code in orchard/optimization/objective/config_builder.py
def __init__(self, base_cfg: Config) -> None:
    """
    Initialize config builder.

    Args:
        base_cfg: Base configuration template
    """
    self.base_cfg = base_cfg
    self.optuna_epochs = base_cfg.optuna.epochs
    self.base_metadata = base_cfg.dataset._ensure_metadata

build(trial_params)

Build trial-specific Config with parameter overrides.

Parameters:

  • trial_params (dict[str, Any], required): Sampled hyperparameters from Optuna

Returns:

  • Config: Validated Config instance with trial parameters

Source code in orchard/optimization/objective/config_builder.py
def build(self, trial_params: dict[str, Any]) -> Config:
    """
    Build trial-specific Config with parameter overrides.

    Args:
        trial_params: Sampled hyperparameters from Optuna

    Returns:
        Validated Config instance with trial parameters
    """
    config_dict = self.base_cfg.model_dump()

    # Preserve resolution
    if config_dict["dataset"].get("resolution") is None:
        config_dict["dataset"]["resolution"] = self.base_cfg.dataset.resolution

    # Re-inject metadata (excluded from serialization)
    config_dict["dataset"]["metadata"] = self.base_metadata

    # Override epochs for Optuna trials
    config_dict["training"]["epochs"] = self.optuna_epochs

    # Cap mixup_epochs to trial length (prevents _check_mixup_epochs ValueError)
    config_dict["training"]["mixup_epochs"] = min(
        config_dict["training"]["mixup_epochs"], self.optuna_epochs
    )

    # Apply trial-specific overrides
    self._apply_param_overrides(config_dict, trial_params)

    return Config(**config_dict)

MetricExtractor(metric_name, direction='maximize')

Extracts and tracks metrics from validation results.

Handles metric extraction with validation and maintains the best metric value achieved during training. Direction-aware: uses max() for maximize objectives, min() for minimize.

Attributes:

  • metric_name: Name of metric to track (e.g., 'auc', 'accuracy')
  • direction: Optimization direction ('maximize' or 'minimize')
  • best_metric: Best metric value achieved so far

Example

extractor = MetricExtractor("auc", direction="maximize")
val_metrics = {"loss": 0.5, "accuracy": 0.85, "auc": 0.92}
current = extractor.extract(val_metrics)  # 0.92
best = extractor.update_best(current)  # 0.92

Initialize metric extractor.

Parameters:

  • metric_name (str, required): Name of metric to track
  • direction (str, default 'maximize'): 'maximize' or 'minimize'

Source code in orchard/optimization/objective/metric_extractor.py
def __init__(self, metric_name: str, direction: str = "maximize") -> None:
    """
    Initialize metric extractor.

    Args:
        metric_name: Name of metric to track
        direction: 'maximize' or 'minimize'
    """
    self.metric_name = metric_name
    self.direction = direction
    self._is_maximize = direction == "maximize"
    self.best_metric = -float("inf") if self._is_maximize else float("inf")

extract(val_metrics)

Extract target metric from validation results.

Parameters:

  • val_metrics (Mapping[str, float], required): Dictionary of validation metrics

Returns:

  • float: Value of target metric

Raises:

  • KeyError: If metric_name not found in val_metrics

Source code in orchard/optimization/objective/metric_extractor.py
def extract(self, val_metrics: Mapping[str, float]) -> float:
    """
    Extract target metric from validation results.

    Args:
        val_metrics: Dictionary of validation metrics

    Returns:
        Value of target metric

    Raises:
        KeyError: If metric_name not found in val_metrics
    """
    if self.metric_name not in val_metrics:
        available = list(val_metrics.keys())
        raise KeyError(f"Metric '{self.metric_name}' not found. Available: {available}")
    return val_metrics[self.metric_name]

reset()

Reset best metric tracking for a new trial.

Source code in orchard/optimization/objective/metric_extractor.py
def reset(self) -> None:
    """Reset best metric tracking for a new trial."""
    self.best_metric = -float("inf") if self._is_maximize else float("inf")

update_best(current_metric)

Update and return best metric achieved within current trial.

Direction-aware: uses max() for maximize, min() for minimize. NaN values are skipped to keep the best-metric state clean: comparisons with NaN are always False in Python, so max(nan, x) returns NaN, and a NaN that ever reached best_metric would win every later comparison.
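
The guard can be verified directly: comparisons with NaN are unordered in Python, so the outcome of max() depends on which argument comes first.

```python
import math

# Comparisons involving NaN are always False, so max()/min() keep
# whichever argument comes first when NaN is involved.
nan, neg_inf = float("nan"), float("-inf")

a = max(neg_inf, nan)  # -inf: NaN as the second argument never wins
b = max(nan, 0.92)     # NaN: a NaN best_metric would win forever
```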

Parameters:

  • current_metric (float, required): Current metric value

Returns:

  • float: Best metric value achieved so far

Source code in orchard/optimization/objective/metric_extractor.py
def update_best(self, current_metric: float) -> float:
    """
    Update and return best metric achieved within current trial.

    Direction-aware: uses max() for maximize, min() for minimize.
    NaN values are skipped to keep the best-metric state clean:
    comparisons with NaN are always False in Python, so ``max(nan, x)``
    returns NaN and a NaN best would win every later comparison.

    Args:
        current_metric: Current metric value

    Returns:
        Best metric value achieved so far
    """
    if math.isnan(current_metric):
        return self.best_metric

    comparator = max if self._is_maximize else min
    self.best_metric = comparator(self.best_metric, current_metric)

    return self.best_metric

OptunaObjective(cfg, search_space, device, dataset_loader=None, dataloader_factory=None, model_factory=None, tracker=None)

Optuna objective function with dependency injection.

Orchestrates hyperparameter optimization trials by:

  • Building trial-specific configurations
  • Creating data loaders, models, and optimizers
  • Executing training with pruning
  • Tracking and returning best metrics

All external dependencies are injectable for testability:

  • dataset_loader: Dataset loading function
  • dataloader_factory: DataLoader creation function
  • model_factory: Model instantiation function

Attributes:

  • cfg: Base configuration (single source of truth)
  • search_space: Hyperparameter search space
  • device: Training device (CPU/CUDA/MPS)
  • config_builder: Builds trial-specific configs
  • metric_extractor: Handles metric extraction
  • dataset_data: Cached dataset (loaded once, reused across trials)

Example

objective = OptunaObjective(
    cfg=config,
    search_space=search_space,
    device=torch.device("cuda"),
)
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)

Initialize Optuna objective.

Parameters:

  • cfg (Config, required): Base configuration (reads optuna.* settings)
  • search_space (Mapping[str, Any], required): Hyperparameter search space
  • device (device, required): Training device
  • dataset_loader (DatasetLoaderProtocol | None, default None): Dataset loading function (default: load_dataset)
  • dataloader_factory (DataloaderFactoryProtocol | None, default None): DataLoader factory (default: get_dataloaders)
  • model_factory (ModelFactoryProtocol | None, default None): Model factory (default: get_model)
  • tracker (TrackerProtocol | None, default None): Optional experiment tracker for nested trial logging

Source code in orchard/optimization/objective/objective.py
def __init__(
    self,
    cfg: Config,
    search_space: Mapping[str, Any],
    device: torch.device,
    dataset_loader: DatasetLoaderProtocol | None = None,
    dataloader_factory: DataloaderFactoryProtocol | None = None,
    model_factory: ModelFactoryProtocol | None = None,
    tracker: TrackerProtocol | None = None,
) -> None:
    """
    Initialize Optuna objective.

    Args:
        cfg: Base configuration (reads optuna.* settings)
        search_space: Hyperparameter search space
        device: Training device
        dataset_loader: Dataset loading function (default: load_dataset)
        dataloader_factory: DataLoader factory (default: get_dataloaders)
        model_factory: Model factory (default: get_model)
        tracker: Optional experiment tracker for nested trial logging
    """
    self.cfg = cfg
    self.search_space = search_space
    self.device = device
    self.tracker = tracker

    # Dependency injection with defaults
    self._dataset_loader = dataset_loader or load_dataset
    self._dataloader_factory = dataloader_factory or get_dataloaders
    self._model_factory = model_factory or get_model

    # Components (monitor_metric is the single source of truth for the
    # optimisation target — shared by trainer checkpointing and Optuna ranking)
    self.config_builder = TrialConfigBuilder(cfg)
    self.metric_extractor = MetricExtractor(
        cfg.training.monitor_metric, direction=cfg.optuna.direction
    )

    # Load dataset once (reused across all trials)
    self.dataset_data = self._dataset_loader(self.config_builder.base_metadata)

__call__(trial)

Execute single Optuna trial.

Samples hyperparameters, builds trial configuration, trains model, and returns best validation metric. Failed trials return the worst possible metric instead of crashing the study.
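
The worst-case fallback presumably returns the worst value for the study direction; a minimal sketch, assuming only the direction string (the real `_worst_metric` is not shown on this page):

```python
# Hypothetical sketch of a worst-case fallback like _worst_metric
# (an assumption; the actual implementation is not shown here).
# A failed trial is ranked last without crashing the Optuna study.
def worst_metric(direction: str) -> float:
    return -float("inf") if direction == "maximize" else float("inf")

worst_metric("maximize")  # -inf
worst_metric("minimize")  # inf
```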

Parameters:

  • trial (Trial, required): Optuna trial object

Returns:

  • float: Best validation metric achieved during training, or worst-case metric if the trial fails

Raises:

  • TrialPruned: If trial is pruned during training

Source code in orchard/optimization/objective/objective.py
def __call__(self, trial: optuna.Trial) -> float:
    """
    Execute single Optuna trial.

    Samples hyperparameters, builds trial configuration, trains model,
    and returns best validation metric. Failed trials return the worst
    possible metric instead of crashing the study.

    Args:
        trial: Optuna trial object

    Returns:
        Best validation metric achieved during training,
        or worst-case metric if the trial fails.

    Raises:
        optuna.TrialPruned: If trial is pruned during training
    """
    # Reset per-trial metric tracking
    self.metric_extractor.reset()

    # Sample parameters
    params = self._sample_params(trial)

    # Build trial config
    trial_cfg = self.config_builder.build(params)

    # Inject recipe-level flags for logging (not Optuna params)
    log_params = {**params, "pretrained": self.cfg.architecture.pretrained}

    # Log trial start
    log_trial_start(trial.number, log_params)

    # Start nested MLflow run for this trial
    if self.tracker is not None:
        self.tracker.start_optuna_trial(trial.number, log_params)

    trial_succeeded = False
    try:
        # Setup training components
        train_loader, val_loader, _ = self._dataloader_factory(
            self.dataset_data,
            trial_cfg.dataset,
            trial_cfg.training,
            trial_cfg.augmentation,
            trial_cfg.num_workers,
            is_optuna=True,
        )
        model = self._model_factory(self.device, trial_cfg.dataset, trial_cfg.architecture)
        optimizer = get_optimizer(model, trial_cfg.training)
        scheduler = get_scheduler(optimizer, trial_cfg.training)

        class_weights = None
        if trial_cfg.training.weighted_loss:
            train_labels = train_loader.dataset.labels.flatten()  # type: ignore[attr-defined]
            num_classes = self.config_builder.base_metadata.num_classes
            class_weights = compute_class_weights(train_labels, num_classes, self.device)

        criterion = get_criterion(trial_cfg.training, class_weights=class_weights)

        # Execute training
        executor = TrialTrainingExecutor(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            criterion=criterion,
            training=trial_cfg.training,
            optuna=trial_cfg.optuna,
            log_interval=trial_cfg.telemetry.log_interval,
            device=self.device,
            metric_extractor=self.metric_extractor,
        )

        best_metric = executor.execute(trial)
        trial_succeeded = True

        return best_metric

    except optuna.TrialPruned:
        trial_succeeded = True  # pruned trials have valid metrics
        raise

    except Exception as e:  # must not crash study
        logger.error(  # pragma: no mutate
            "%s%s Trial %d failed: %s: %s",
            LogStyle.INDENT,  # pragma: no mutate
            LogStyle.FAILURE,  # pragma: no mutate
            trial.number,  # pragma: no mutate
            type(e).__name__,  # pragma: no mutate
            e,  # pragma: no mutate
        )
        return self._worst_metric()

    finally:
        # End nested MLflow run for this trial
        if self.tracker is not None:
            if trial_succeeded:
                self.tracker.end_optuna_trial(self.metric_extractor.best_metric)
            else:
                # Trial failed before any validation — close run without metric
                self.tracker.end_optuna_trial(self._worst_metric())

        # Cleanup GPU memory between trials
        self._cleanup()

TrialTrainingExecutor(model, train_loader, val_loader, optimizer, scheduler, criterion, training, optuna, log_interval, device, metric_extractor)

Executes training loop with Optuna pruning integration.

Orchestrates a complete training cycle for a single Optuna trial, including:

  • Training and validation epochs
  • Metric extraction and tracking
  • Pruning decisions with warmup period
  • Learning rate scheduling
  • Progress logging

Pruning and warmup parameters are read from the optuna sub-config; training hyperparameters from training.
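
The warmup behaviour can be sketched as a guard around Optuna's standard trial.should_prune() call. This is an illustrative assumption: the actual `_should_prune` implementation is not shown on this page.

```python
# Illustrative sketch of a warmup-aware pruning check (an assumption;
# the real _should_prune is not shown here). Assumes Optuna's
# standard trial.should_prune() API.
def should_prune(trial, epoch: int, enable_pruning: bool, warmup_epochs: int) -> bool:
    if not enable_pruning or epoch <= warmup_epochs:
        return False  # never prune while disabled or during warmup
    return trial.should_prune()
```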

Attributes:

  • model: PyTorch model to train.
  • train_loader: Training data loader.
  • val_loader: Validation data loader.
  • optimizer: Optimizer instance.
  • scheduler: Learning rate scheduler.
  • criterion: Loss function.
  • device: Training device (CPU/CUDA/MPS).
  • metric_extractor: Handles metric extraction and best-value tracking.
  • enable_pruning: Whether to enable trial pruning.
  • warmup_epochs: Epochs before pruning activates.
  • monitor_metric: Name of the metric driving scheduling.
  • scaler (GradScaler | None): AMP gradient scaler (None when use_amp is False).
  • mixup_fn (callable | None): Mixup augmentation function (None when alpha is 0).
  • epochs: Total training epochs.
  • log_interval: Epoch interval for progress logging.
  • _loop (TrainingLoop): Shared epoch kernel for training steps (train only, no validation).

Example

executor = TrialTrainingExecutor(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    training=trial_cfg.training,
    optuna=trial_cfg.optuna,
    log_interval=trial_cfg.telemetry.log_interval,
    device=device,
    metric_extractor=MetricExtractor("auc"),
)
best_metric = executor.execute(trial)

Initialize training executor.

Parameters:

  • model (Module, required): PyTorch model to train.
  • train_loader (DataLoader[Any], required): Training data loader.
  • val_loader (DataLoader[Any], required): Validation data loader.
  • optimizer (Optimizer, required): Optimizer instance.
  • scheduler (LRScheduler, required): Learning rate scheduler.
  • criterion (Module, required): Loss function.
  • training (TrainingConfig, required): Training hyperparameters sub-config.
  • optuna (OptunaConfig, required): Optuna pruning/warmup sub-config.
  • log_interval (int, required): Epoch interval for progress logging.
  • device (device, required): Training device.
  • metric_extractor (MetricExtractor, required): Metric extraction and tracking handler.

Source code in orchard/optimization/objective/training_executor.py
def __init__(
    self,
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader[Any],
    val_loader: torch.utils.data.DataLoader[Any],
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    criterion: torch.nn.Module,
    training: TrainingConfig,
    optuna: OptunaConfig,
    log_interval: int,
    device: torch.device,
    metric_extractor: MetricExtractor,
) -> None:
    """
    Initialize training executor.

    Args:
        model: PyTorch model to train.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        optimizer: Optimizer instance.
        scheduler: Learning rate scheduler.
        criterion: Loss function.
        training: Training hyperparameters sub-config.
        optuna: Optuna pruning/warmup sub-config.
        log_interval: Epoch interval for progress logging.
        device: Training device.
        metric_extractor: Metric extraction and tracking handler.
    """
    self.model = model
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.optimizer = optimizer
    self.scheduler = scheduler
    self.criterion = criterion
    self.device = device
    self.metric_extractor = metric_extractor

    # Pruning config
    self.enable_pruning = optuna.enable_pruning
    self.warmup_epochs = optuna.pruning_warmup_epochs

    # Training state
    self.scaler = create_amp_scaler(training, device=str(device))
    self.mixup_fn = create_mixup_fn(training)
    self.epochs = training.epochs
    self.monitor_metric = training.monitor_metric
    self.log_interval = log_interval
    self._consecutive_val_failures: int = 0

    # Shared epoch kernel (train step only — validation is error-resilient here)
    self._loop = TrainingLoop(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        device=device,
        scaler=self.scaler,
        mixup_fn=self.mixup_fn,
        options=LoopOptions(
            grad_clip=training.grad_clip,
            total_epochs=self.epochs,
            mixup_epochs=training.mixup_epochs,
            use_tqdm=False,
            monitor_metric=self.monitor_metric,
        ),
    )

execute(trial)

Execute full training loop with pruning.

Runs training for cfg.training.epochs, reporting metrics to Optuna after each epoch. Applies pruning logic after warmup period.

Parameters:

  • trial (Trial, required): Optuna trial for reporting and pruning

Returns:

  • float: Best validation metric achieved during training

Raises:

  • TrialPruned: If trial should terminate early

Source code in orchard/optimization/objective/training_executor.py
def execute(self, trial: optuna.Trial) -> float:
    """
    Execute full training loop with pruning.

    Runs training for cfg.training.epochs, reporting metrics to Optuna
    after each epoch. Applies pruning logic after warmup period.

    Args:
        trial: Optuna trial for reporting and pruning

    Returns:
        Best validation metric achieved during training

    Raises:
        optuna.TrialPruned: If trial should terminate early
    """
    for epoch in range(1, self.epochs + 1):
        # Train (delegated to shared loop)
        epoch_loss = self._loop.run_train_step(epoch)

        # Validate
        val_metrics = self._validate_epoch()

        # Extract and track metric
        current_metric = self.metric_extractor.extract(val_metrics)
        best_metric = self.metric_extractor.update_best(current_metric)

        # Report to Optuna (skip NaN to avoid poisoning the pruner)
        if not math.isnan(current_metric):
            trial.report(current_metric, epoch)

        # Check pruning
        if self._should_prune(trial, epoch):
            logger.info(
                "%s%s Trial %d pruned at epoch %d (%s=%.4f)",
                LogStyle.INDENT,
                LogStyle.ARROW,
                trial.number,
                epoch,
                self.metric_extractor.metric_name,
                current_metric,
            )
            raise optuna.TrialPruned()

        # Scheduler step (uses monitor_metric, consistent with ModelTrainer)
        step_scheduler(self.scheduler, val_metrics[self.monitor_metric])

        # Logging
        if epoch % self.log_interval == 0 or epoch == self.epochs:
            logger.info(
                "%sT%d E%d/%d | Loss:%.4f | %s:%.4f (Best:%.4f)",
                LogStyle.DOUBLE_INDENT,
                trial.number,
                epoch,
                self.epochs,
                epoch_loss,
                self.metric_extractor.metric_name,
                current_metric,
                best_metric,
            )

    self._log_trial_complete(trial, best_metric, epoch_loss)
    return best_metric