training_executor
orchard.optimization.objective.training_executor
Training execution utilities for Optuna trials.
Provides TrialTrainingExecutor, which orchestrates the training and validation
loop for a single Optuna trial with built-in pruning, metric tracking, and
scheduler management. Per-epoch training is delegated to _loop.TrainingLoop
(shared with ModelTrainer), while validation remains local with error-resilient
fallback metrics.
Key responsibilities:
- Execute epoch-level training/validation cycles
- Apply Optuna pruning logic with warmup period
- Track and report metrics to Optuna
- Handle scheduler stepping (plateau-aware)
- Provide error-resilient validation with fallback metrics
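The last responsibility, error-resilient validation with fallback metrics, can be pictured as a wrapper that catches any exception raised by a validation pass and returns worst-case values instead. This is a minimal sketch, not the module's code: safe_validate and FALLBACK_METRICS are hypothetical names. The fallback values (infinite loss, zero accuracy and AUC) are also assumptions, chosen so that a failed epoch can never register as an improvement.

```python
import logging
import math
from typing import Callable, Dict

logger = logging.getLogger(__name__)

# Hypothetical fallback values: worst-case numbers, so a failed validation
# pass can never be mistaken for an improvement.
FALLBACK_METRICS: Dict[str, float] = {"loss": math.inf, "accuracy": 0.0, "auc": 0.0}


def safe_validate(validate_fn: Callable[[], Dict[str, float]]) -> Dict[str, float]:
    """Run one validation pass; on any error, log it and return fallback metrics."""
    try:
        return validate_fn()
    except Exception:  # deliberately broad: one failed pass should not end the trial
        logger.exception("Validation failed; falling back to worst-case metrics")
        return dict(FALLBACK_METRICS)  # copy, so callers cannot mutate the defaults
```

The copy matters if the caller later mutates the returned dict, for example to add derived metrics.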
.. todo::
Unify TrialTrainingExecutor and ModelTrainer into a single
engine with pluggable epoch-end callbacks (early stopping,
checkpointing, Optuna pruning). Both already share the full
training kernel (TrainingLoop, validate_epoch,
step_scheduler, AMP scaler, Mixup); the only divergence is
the epoch-level loop and post-validation actions.
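A single engine with pluggable epoch-end callbacks, as the todo describes, could be sketched as follows. The sketch rests on a few assumptions: run_engine, EpochEndCallback and EarlyStopping are hypothetical names, each callback receives the epoch number and the monitored metric, and returning True requests a stop. An Optuna pruning callback could occupy the same slot by calling trial.report and raising optuna.TrialPruned.

```python
import math
from typing import Callable, List, Protocol


class EpochEndCallback(Protocol):
    """Hypothetical hook: return True to request that training stop."""

    def __call__(self, epoch: int, metric: float) -> bool: ...


class EarlyStopping:
    """Stop once the metric has not improved for `patience` consecutive epochs."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.best = -math.inf  # assumes a metric to maximize
        self.bad_epochs = 0

    def __call__(self, epoch: int, metric: float) -> bool:
        if metric > self.best:
            self.best, self.bad_epochs = metric, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def run_engine(
    epochs: int,
    train_epoch: Callable[[int], None],
    validate: Callable[[int], float],
    callbacks: List[EpochEndCallback],
) -> List[float]:
    """Shared epoch loop; every post-validation action lives in a callback."""
    history: List[float] = []
    for epoch in range(1, epochs + 1):
        train_epoch(epoch)
        metric = validate(epoch)
        history.append(metric)
        # Call every callback (no short-circuit) so each one sees every epoch.
        stop_requests = [cb(epoch, metric) for cb in callbacks]
        if any(stop_requests):
            break
    return history
```

In this shape, ModelTrainer would register early-stopping and checkpointing callbacks while the Optuna path would register a pruning callback, and the epoch loop itself would be shared.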
TrialTrainingExecutor(model, train_loader, val_loader, optimizer, scheduler, criterion, training, optuna, log_interval, device, metric_extractor)
Executes the training loop with Optuna pruning integration.
Orchestrates a complete training cycle for a single Optuna trial, including:
- Training and validation epochs
- Metric extraction and tracking
- Pruning decisions with warmup period
- Learning rate scheduling
- Progress logging
Pruning and warmup parameters are read from the optuna sub-config;
training hyperparameters are read from the training sub-config.
Attributes:
| Name | Type | Description |
|---|---|---|
| model | | PyTorch model to train. |
| train_loader | | Training data loader. |
| val_loader | | Validation data loader. |
| optimizer | | Optimizer instance. |
| scheduler | | Learning rate scheduler. |
| criterion | | Loss function. |
| device | | Training device (CPU/CUDA/MPS). |
| metric_extractor | | Handles metric extraction and best-value tracking. |
| enable_pruning | | Whether to enable trial pruning. |
| warmup_epochs | | Epochs before pruning activates. |
| monitor_metric | | Name of the metric driving scheduling. |
| scaler | GradScaler \| None | AMP gradient scaler (None when use_amp is False). |
| mixup_fn | callable \| None | Mixup augmentation function (None when alpha is 0). |
| epochs | | Total training epochs. |
| log_interval | | Epoch interval for progress logging. |
| _loop | TrainingLoop | Shared epoch kernel for training steps (train only, no validation). |
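The mixup_fn attribute presumably follows the standard mixup recipe: draw λ from Beta(α, α), blend each sample with a randomly chosen partner from the same batch, and weight the loss on the two target sets by λ and 1 − λ. The pure-Python sketch below works on lists of floats rather than tensors and uses a hypothetical name (build_mixup_fn); it illustrates the technique and the "None when alpha is 0" convention, not the project's implementation.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]
MixupResult = Tuple[List[Vector], list, list, float]


def build_mixup_fn(
    alpha: float, rng: Optional[random.Random] = None
) -> Optional[Callable[[Sequence[Vector], Sequence], MixupResult]]:
    """Return a mixup function, or None when alpha <= 0 (mixup disabled)."""
    if alpha <= 0:
        return None
    rng = rng if rng is not None else random.Random()

    def mixup(inputs: Sequence[Vector], targets: Sequence) -> MixupResult:
        # One mixing coefficient per batch, drawn from Beta(alpha, alpha).
        lam = rng.betavariate(alpha, alpha)
        perm = list(range(len(inputs)))
        rng.shuffle(perm)  # pair each sample with a random partner
        mixed = [
            [lam * a + (1.0 - lam) * b for a, b in zip(inputs[i], inputs[j])]
            for i, j in enumerate(perm)
        ]
        # Caller computes: lam * loss(out, targets_a) + (1 - lam) * loss(out, targets_b)
        return mixed, list(targets), [targets[j] for j in perm], lam

    return mixup
```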
Example
```python
# model, the loaders, optimizer, scheduler, criterion, device, trial_cfg
# and trial are assumed to exist already.
executor = TrialTrainingExecutor(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    training=trial_cfg.training,
    optuna=trial_cfg.optuna,
    log_interval=trial_cfg.telemetry.log_interval,
    device=device,
    metric_extractor=MetricExtractor("auc"),
)
best_metric = executor.execute(trial)
```
Initialize the training executor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | PyTorch model to train. | required |
| train_loader | DataLoader[Any] | Training data loader. | required |
| val_loader | DataLoader[Any] | Validation data loader. | required |
| optimizer | Optimizer | Optimizer instance. | required |
| scheduler | LRScheduler | Learning rate scheduler. | required |
| criterion | Module | Loss function. | required |
| training | TrainingConfig | Training hyperparameters sub-config. | required |
| optuna | OptunaConfig | Optuna pruning/warmup sub-config. | required |
| log_interval | int | Epoch interval for progress logging. | required |
| device | device | Training device. | required |
| metric_extractor | MetricExtractor | Metric extraction and tracking handler. | required |
execute(trial)
Execute the full training loop with pruning.
Runs training for training.epochs epochs, reporting metrics to Optuna after each epoch. Pruning logic is applied only once the warmup period has elapsed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trial | Trial | Optuna trial for reporting and pruning. | required |
Returns:
| Type | Description |
|---|---|
| float | Best validation metric achieved during training. |
Raises:
| Type | Description |
|---|---|
| TrialPruned | If the trial should terminate early. |
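The control flow described for execute() can be sketched around the two Optuna calls it relies on, trial.report(value, step) and trial.should_prune(). To keep the sketch runnable without Optuna, TrialPruned is a local stand-in for optuna.TrialPruned, and any object with those two methods can serve as the trial. execute_sketch, its parameters, the maximize-only best tracking and the `epoch > warmup_epochs` boundary are assumptions, not the executor's code.

```python
import math
from typing import Callable, Dict, Protocol


class TrialPruned(Exception):
    """Local stand-in for optuna.TrialPruned so the sketch runs without Optuna."""


class TrialLike(Protocol):
    # The two optuna.trial.Trial methods the loop relies on.
    def report(self, value: float, step: int) -> None: ...
    def should_prune(self) -> bool: ...


def execute_sketch(
    trial: TrialLike,
    epochs: int,
    train_epoch: Callable[[int], None],
    validate: Callable[[int], Dict[str, float]],
    metric_name: str = "auc",
    enable_pruning: bool = True,
    warmup_epochs: int = 5,
) -> float:
    """Train, report each epoch, prune after warmup, return the best metric."""
    best = -math.inf  # assumes a metric to maximize (e.g. AUC)
    for epoch in range(1, epochs + 1):
        train_epoch(epoch)
        value = validate(epoch)[metric_name]
        best = max(best, value)
        trial.report(value, step=epoch)
        # The pruner is only consulted once the warmup period has passed.
        if enable_pruning and epoch > warmup_epochs and trial.should_prune():
            raise TrialPruned(f"pruned at epoch {epoch}")
    return best
```

With a real study, pass the optuna.Trial received by the objective and raise optuna.TrialPruned instead of the local class, so that Optuna records the trial as pruned rather than failed.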