training_executor
orchard.optimization.objective.training_executor
Training execution utilities for Optuna trials.
Provides TrialTrainingExecutor, which orchestrates the training and validation
loop for a single Optuna trial with built-in pruning, metric tracking, and
scheduler management. Per-epoch training is delegated to _loop.TrainingLoop
(shared with ModelTrainer), while validation remains local with error-resilient
fallback metrics.
Key responsibilities:
- Execute epoch-level training/validation cycles
- Apply Optuna pruning logic with warmup period
- Track and report metrics to Optuna
- Handle scheduler stepping (plateau-aware)
- Provide error-resilient validation with fallback metrics
.. todo::
Unify TrialTrainingExecutor and ModelTrainer into a single
engine with pluggable epoch-end callbacks (early stopping,
checkpointing, Optuna pruning). Both already share the full
training kernel (TrainingLoop, validate_epoch,
step_scheduler, AMP scaler, Mixup); the only divergence is
the epoch-level loop and post-validation actions.
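The epoch-level loop with warmup-gated pruning described above can be sketched as follows. All names here (`run_trial`, `train_one_epoch`, `validate`, `report`, `should_prune`) are illustrative stand-ins for the executor's internals, not Orchard or Optuna APIs:

```python
class TrialPrunedSketch(Exception):
    """Stand-in for optuna.TrialPruned."""

def run_trial(epochs, warmup_epochs, train_one_epoch, validate, report, should_prune):
    best = float("-inf")
    for epoch in range(epochs):
        train_one_epoch(epoch)            # delegated to the shared TrainingLoop
        metric = validate(epoch)          # local, error-resilient validation
        best = max(best, metric)
        report(metric, epoch)             # i.e. trial.report(metric, step=epoch)
        # Pruning is only consulted once the warmup period has elapsed.
        if epoch >= warmup_epochs and should_prune():
            raise TrialPrunedSketch
    return best
```

The warmup gate prevents the pruner from killing trials on noisy early-epoch metrics before the model has had a chance to converge.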
TaskAdapters(training_step=None, validation_metrics=None, fallback_metrics=None)
dataclass
Bundle of task-specific adapters injected into the training executor.
Attributes:

| Name | Type | Description |
|---|---|---|
| `training_step` | `TaskTrainingStep \| None` | Custom forward pass adapter. |
| `validation_metrics` | `TaskValidationMetrics \| None` | Custom validation metrics adapter. |
| `fallback_metrics` | `Mapping[str, float] \| None` | Metrics returned on validation failure. |
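As a sketch of how this bundle behaves (using a minimal stand-in dataclass, since the real `TaskTrainingStep` and `TaskValidationMetrics` types live in Orchard), `fallback_metrics` is what error-resilient validation returns when the real metrics computation fails:

```python
from dataclasses import dataclass
from typing import Any, Callable, Mapping, Optional

@dataclass
class TaskAdaptersSketch:
    # Mirrors the documented fields; the adapter types are placeholders.
    training_step: Optional[Callable[..., Any]] = None
    validation_metrics: Optional[Callable[..., Mapping[str, float]]] = None
    fallback_metrics: Optional[Mapping[str, float]] = None

def resilient_metrics(compute, adapters):
    # On validation failure, fall back to the configured metrics
    # instead of aborting the whole trial.
    try:
        return dict(compute())
    except Exception:
        return dict(adapters.fallback_metrics or {})
```

All three fields default to `None`, so callers only override the adapters their task actually needs.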
TrialTrainingExecutor(model, train_loader, val_loader, optimizer, scheduler, criterion, training, optuna, log_interval, device, metric_extractor, task_adapters=None)
Executes training loop with Optuna pruning integration.
Orchestrates a complete training cycle for a single Optuna trial, including:
- Training and validation epochs
- Metric extraction and tracking
- Pruning decisions with warmup period
- Learning rate scheduling
- Progress logging
Pruning and warmup parameters are read from the `optuna` sub-config; training hyperparameters from the `training` sub-config.
Attributes:

| Name | Type | Description |
|---|---|---|
| `model` | | PyTorch model to train. |
| `train_loader` | | Training data loader. |
| `val_loader` | | Validation data loader. |
| `optimizer` | | Optimizer instance. |
| `scheduler` | | Learning rate scheduler. |
| `criterion` | | Loss function. |
| `device` | | Training device (CPU/CUDA/MPS). |
| `metric_extractor` | | Handles metric extraction and best-value tracking. |
| `enable_pruning` | | Whether to enable trial pruning. |
| `warmup_epochs` | | Epochs before pruning activates. |
| `monitor_metric` | | Name of the metric driving scheduling. |
| `scaler` | `GradScaler \| None` | AMP gradient scaler (`None` when `use_amp` is False). |
| `mixup_fn` | `callable \| None` | Mixup augmentation function (`None` when alpha is 0). |
| `epochs` | | Total training epochs. |
| `log_interval` | | Epoch interval for progress logging. |
| `_loop` | `TrainingLoop` | Shared epoch kernel for training steps (train only, no validation). |
Example:

```python
executor = TrialTrainingExecutor(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    training=trial_cfg.training,
    optuna=trial_cfg.optuna,
    log_interval=trial_cfg.telemetry.log_interval,
    device=device,
    metric_extractor=MetricExtractor("auc"),
    task_adapters=TaskAdapters(training_step=task.training_step),
)
best_metric = executor.execute(trial)
```
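The plateau-aware scheduler stepping listed among the executor's responsibilities typically distinguishes metric-driven schedulers from epoch-driven ones. A minimal sketch of that dispatch (the duck-typed class-name check is an assumption about the implementation, not Orchard's actual code):

```python
def step_scheduler(scheduler, metric_value):
    # ReduceLROnPlateau-style schedulers consume the monitored metric;
    # ordinary epoch-based schedulers step unconditionally.
    if type(scheduler).__name__ == "ReduceLROnPlateau":
        scheduler.step(metric_value)
    else:
        scheduler.step()
```

This is why the executor tracks `monitor_metric`: plateau schedulers need the monitored value at every epoch boundary.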
Initialize training executor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | PyTorch model to train. | *required* |
| `train_loader` | `DataLoader[Any]` | Training data loader. | *required* |
| `val_loader` | `DataLoader[Any]` | Validation data loader. | *required* |
| `optimizer` | `Optimizer` | Optimizer instance. | *required* |
| `scheduler` | `LRScheduler` | Learning rate scheduler. | *required* |
| `criterion` | `Module` | Loss function. | *required* |
| `training` | `TrainingConfig` | Training hyperparameters sub-config. | *required* |
| `optuna` | `OptunaConfig` | Optuna pruning/warmup sub-config. | *required* |
| `log_interval` | `int` | Epoch interval for progress logging. | *required* |
| `device` | `device` | Training device. | *required* |
| `metric_extractor` | `MetricExtractor` | Metric extraction and tracking handler. | *required* |
| `task_adapters` | `TaskAdapters \| None` | Task-specific adapters bundle: training step, validation metrics, and fallback metrics. | `None` |
Source code in orchard/optimization/objective/training_executor.py
execute(trial)
Execute full training loop with pruning.
Runs training for `cfg.training.epochs` epochs, reporting metrics to Optuna after each epoch, and applies pruning logic after the warmup period.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `trial` | `Trial` | Optuna trial for reporting and pruning. | *required* |

Returns:

| Type | Description |
|---|---|
| `float` | Best validation metric achieved during training. |

Raises:

| Type | Description |
|---|---|
| `TrialPruned` | If the trial should terminate early. |
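On the caller side, `TrialPruned` is not an error to swallow: the study loop catches it and records the trial as pruned rather than failed. A self-contained sketch of that contract (the exception class, objective, and study driver are stand-ins, not Orchard or Optuna code):

```python
class TrialPrunedSketch(Exception):
    """Stand-in for optuna.TrialPruned."""

def run_study(objective, trials):
    # Mimics how a study records completed vs. pruned trials:
    # a pruned trial contributes no metric but does not abort the study.
    results = []
    for trial in trials:
        try:
            results.append(("complete", objective(trial)))
        except TrialPrunedSketch:
            results.append(("pruned", None))
    return results
```

In real usage the objective would construct a `TrialTrainingExecutor` and return `executor.execute(trial)`, letting `TrialPruned` propagate up to the study.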