
criterion_adapter

orchard.tasks.classification.criterion_adapter

Classification Criterion Adapter.

Wraps orchard.trainer.setup.get_criterion so that it satisfies the TaskCriterionFactory protocol defined in orchard.core.task_protocols.
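A minimal sketch of the adapter pattern described above. It assumes TaskCriterionFactory is a typing.Protocol with a single get_criterion method; the real definition in orchard.core.task_protocols may differ. TrainingConfig, its criterion_type and focal_gamma fields, and build_criterion are hypothetical stand-ins. build_criterion returns a tuple instead of a torch.nn.Module so the sketch runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Optional, Protocol, Sequence, runtime_checkable


# Hypothetical stand-in for orchard's TrainingConfig; real field names may differ.
@dataclass
class TrainingConfig:
    criterion_type: str = "cross_entropy"  # assumed values: "cross_entropy" or "focal"
    focal_gamma: float = 2.0               # assumed focal-loss parameter


@runtime_checkable
class TaskCriterionFactory(Protocol):
    """Structural interface: any object with a matching get_criterion conforms."""

    def get_criterion(
        self,
        training: TrainingConfig,
        class_weights: Optional[Sequence[float]] = None,
    ) -> object: ...


def build_criterion(training: TrainingConfig, class_weights=None):
    """Hypothetical stand-in for orchard.trainer.setup.get_criterion.

    Returns a plain tuple describing the loss instead of a torch.nn.Module.
    """
    if training.criterion_type == "focal":
        return ("FocalLoss", training.focal_gamma, class_weights)
    if training.criterion_type == "cross_entropy":
        return ("CrossEntropyLoss", None, class_weights)
    raise ValueError(f"unknown criterion_type: {training.criterion_type!r}")


class ClassificationCriterionAdapter:
    """Forwards to the free function without adding logic of its own."""

    def get_criterion(self, training: TrainingConfig, class_weights=None):
        return build_criterion(training, class_weights=class_weights)


adapter = ClassificationCriterionAdapter()
# isinstance on a runtime_checkable Protocol only checks that the method exists.
assert isinstance(adapter, TaskCriterionFactory)
print(adapter.get_criterion(TrainingConfig(criterion_type="focal"), class_weights=[1.0, 3.0]))
```

Because Protocol conformance is structural, the adapter needs no base class; it only has to expose a compatible get_criterion method.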

ClassificationCriterionAdapter

Builds classification loss functions (CrossEntropy / Focal).
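The two losses differ by a modulating factor. For one sample with true class y, softmax probability p_y, optional class weight w_y, and focusing parameter gamma, cross-entropy is -w_y log p_y, and the standard focal-loss form is -w_y (1 - p_y)^gamma log p_y. The pure-Python sketch below computes both per sample, without reduction. It assumes orchard's FocalLoss follows this common form; the source does not confirm its parameter names or reduction.

```python
import math
from typing import Optional, Sequence


def softmax(logits: Sequence[float]) -> list[float]:
    # Subtract the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target: int, class_weights: Optional[Sequence[float]] = None) -> float:
    """Per-sample weighted cross-entropy: -w_y * log(p_y)."""
    p_t = softmax(logits)[target]
    w = 1.0 if class_weights is None else class_weights[target]
    return -w * math.log(p_t)


def focal_loss(logits, target: int, gamma: float = 2.0,
               class_weights: Optional[Sequence[float]] = None) -> float:
    """Per-sample focal loss: -w_y * (1 - p_y)**gamma * log(p_y)."""
    p_t = softmax(logits)[target]
    w = 1.0 if class_weights is None else class_weights[target]
    # (1 - p_t)**gamma shrinks the loss of samples that are already well classified.
    return -w * (1.0 - p_t) ** gamma * math.log(p_t)


logits = [2.0, 0.5, -1.0]
print(cross_entropy(logits, target=0))          # confident, correct prediction
print(focal_loss(logits, target=0, gamma=2.0))  # smaller than the cross-entropy value
```

With gamma = 0 the focal loss reduces to cross-entropy. Batch reduction is not shown; note that torch.nn.CrossEntropyLoss with a weight argument and reduction='mean' divides by the sum of the target-class weights rather than by the batch size.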

get_criterion(training, class_weights=None)

Delegate to the existing criterion factory.

Parameters:

Name            Type              Description                                            Default
training        TrainingConfig    Training sub-config with criterion parameters.         required
class_weights   Tensor | None     Optional per-class weights for imbalanced datasets.    None
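class_weights is typically derived from label frequencies. Below is a minimal sketch of one common choice, inverse-frequency weighting, w_c = N / (K * n_c), where N is the number of samples, K the number of classes, and n_c the number of samples in class c. inverse_frequency_weights is a hypothetical helper, not part of orchard, and the source does not say how orchard computes its weights.

```python
from collections import Counter
from typing import Sequence


def inverse_frequency_weights(labels: Sequence[int], num_classes: int) -> list[float]:
    """Return w_c = N / (K * n_c); a perfectly balanced dataset gets all weights 1.0."""
    counts = Counter(labels)
    n = len(labels)
    weights = []
    for c in range(num_classes):
        if counts[c] == 0:
            # An absent class would need an infinite weight; fail loudly instead.
            raise ValueError(f"class {c} has no samples")
        weights.append(n / (num_classes * counts[c]))
    return weights


print(inverse_frequency_weights([0, 0, 0, 1], num_classes=2))  # [0.6666666666666666, 2.0]
```

To pass the result as class_weights, convert it to a tensor first, for example torch.tensor(weights, dtype=torch.float32), placed on the same device as the model outputs.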

Returns:

Type      Description
Module    Loss module (CrossEntropyLoss or FocalLoss).

Source code in orchard/tasks/classification/criterion_adapter.py
def get_criterion(
    self,
    training: TrainingConfig,
    class_weights: torch.Tensor | None = None,
) -> nn.Module:
    """
    Delegate to the existing criterion factory.

    Args:
        training: Training sub-config with criterion parameters.
        class_weights: Optional per-class weights for imbalanced datasets.

    Returns:
        Loss module (CrossEntropyLoss or FocalLoss).
    """
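    # The bare name resolves to the module-level get_criterion imported from
    # orchard.trainer.setup, not to this method.
    return get_criterion(training, class_weights=class_weights)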