Skip to content

export

orchard.export

Model Export Package.

Provides utilities for exporting trained PyTorch models to ONNX format with validation, benchmarking, and optimization support.

Example

>>> from orchard.export import export_to_onnx
>>> export_to_onnx(
...     model=trained_model,
...     checkpoint_path="outputs/best_model.pth",
...     output_path="exports/model.onnx",
...     input_shape=(3, 224, 224),
... )

benchmark_onnx_inference(onnx_path, input_shape, num_runs=100, seed=42, label='ONNX')

Benchmark ONNX model inference speed.

Parameters:

Name Type Description Default
onnx_path Path

Path to ONNX model

required
input_shape tuple[int, int, int]

Input tensor shape (C, H, W)

required
num_runs int

Number of inference runs for averaging

100
seed int

Random seed for reproducible dummy input

42
label str

Display label for the benchmark log header

'ONNX'

Returns:

Type Description
float

Average inference time in milliseconds

Example

>>> latency = benchmark_onnx_inference(Path("model.onnx"))
>>> print(f"Latency: {latency:.2f}ms")

Source code in orchard/export/onnx_exporter.py
def benchmark_onnx_inference(
    onnx_path: Path,
    input_shape: tuple[int, int, int],
    num_runs: int = 100,
    seed: int = 42,
    label: str = "ONNX",
) -> float:
    """
    Benchmark ONNX model inference speed.

    Args:
        onnx_path: Path to ONNX model
        input_shape: Input tensor shape (C, H, W)
        num_runs: Number of inference runs for averaging
        seed: Random seed for reproducible dummy input
        label: Display label for the benchmark log header

    Returns:
        Average inference time in milliseconds, or -1.0 if onnxruntime is
        not installed or the benchmark fails.

    Example:
        >>> latency = benchmark_onnx_inference(Path("model.onnx"))
        >>> print(f"Latency: {latency:.2f}ms")
    """
    try:
        import time

        import numpy as np
        import onnxruntime as ort

        logger.info("  [Benchmark — %s]", label)

        # Create inference session
        session = ort.InferenceSession(str(onnx_path))

        # Prepare dummy input using N(0,1) (matches validation distribution)
        rng = np.random.default_rng(seed)
        dummy_input = rng.standard_normal(size=(1, *input_shape)).astype(np.float32)

        # NOTE(review): assumes the model's input tensor is named "input"
        # (true for models produced by export_to_onnx in this package) —
        # confirm before benchmarking externally exported models.
        # Warmup amortizes session lazy initialization out of the timing.
        for _ in range(10):
            session.run(None, {"input": dummy_input})

        # Benchmark. perf_counter() is monotonic and high-resolution;
        # time.time() is wall clock and can jump (NTP adjustments, DST),
        # which would corrupt the measured latency.
        start = time.perf_counter()
        for _ in range(num_runs):
            session.run(None, {"input": dummy_input})
        elapsed = time.perf_counter() - start

        avg_latency_ms = (elapsed / num_runs) * 1000
        logger.info("    %s Runs              : %s", LogStyle.BULLET, num_runs)
        logger.info("    %s Avg latency       : %.2fms", LogStyle.BULLET, avg_latency_ms)
        logger.info("")

        return avg_latency_ms

    except ImportError:
        logger.warning("onnxruntime not installed. Skipping benchmark.")
        return -1.0
    except Exception as e:  # onnxruntime raises non-standard exceptions
        logger.error("Benchmark failed: %s", e)
        return -1.0

export_to_onnx(model, checkpoint_path, output_path, input_shape, opset_version=18, dynamic_axes=True, do_constant_folding=True, validate=True)

Export trained PyTorch model to ONNX format.

Parameters:

Name Type Description Default
model Module

PyTorch model architecture (uninitialized weights OK)

required
checkpoint_path Path

Path to trained .pth checkpoint

required
output_path Path

Output path for .onnx file

required
input_shape tuple[int, int, int]

Input tensor shape (C, H, W)

required
opset_version int

ONNX opset version (default: 18)

18
dynamic_axes bool

Enable dynamic batch size (required for production)

True
do_constant_folding bool

Optimize constant operations at export

True
validate bool

Validate exported model with ONNX checker

True

Raises:

Type Description
FileNotFoundError

If checkpoint_path does not exist.

RuntimeError

If state_dict loading fails (architecture mismatch).

ValueError

If ONNX validation fails (when validate=True).

Example

>>> export_to_onnx(
...     model=EfficientNet(),
...     checkpoint_path=Path("outputs/best_model.pth"),
...     output_path=Path("exports/model.onnx"),
...     input_shape=(3, 224, 224),
... )

Source code in orchard/export/onnx_exporter.py
def export_to_onnx(
    model: nn.Module,
    checkpoint_path: Path,
    output_path: Path,
    input_shape: tuple[int, int, int],
    opset_version: int = 18,
    dynamic_axes: bool = True,
    do_constant_folding: bool = True,
    validate: bool = True,
) -> None:
    """
    Export trained PyTorch model to ONNX format.

    Loads weights from *checkpoint_path* into *model*, traces it on a dummy
    input, writes the ONNX graph to *output_path*, and (optionally) runs the
    ONNX checker on the result. On validation failure the invalid file is
    deleted and the error re-raised.

    Args:
        model: PyTorch model architecture (uninitialized weights OK)
        checkpoint_path: Path to trained .pth checkpoint
        output_path: Output path for .onnx file
        input_shape: Input tensor shape (C, H, W)
        opset_version: ONNX opset version (default: 18)
        dynamic_axes: Enable dynamic batch size (required for production)
        do_constant_folding: Optimize constant operations at export
        validate: Validate exported model with ONNX checker

    Raises:
        FileNotFoundError: If checkpoint_path does not exist.
        RuntimeError: If state_dict loading fails (architecture mismatch).
        ValueError: If ONNX validation fails (when validate=True).

    Example:
        >>> export_to_onnx(
        ...     model=EfficientNet(),
        ...     checkpoint_path=Path("outputs/best_model.pth"),
        ...     output_path=Path("exports/model.onnx"),
        ...     input_shape=(3, 224, 224),
        ... )
    """
    logger.info("  [Source]")
    logger.info("    %s Checkpoint        : %s", LogStyle.BULLET, checkpoint_path.name)
    logger.info("")

    # Create output directory if needed
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Load trained weights. weights_only=True restricts unpickling to plain
    # tensors/containers, blocking arbitrary-code execution from a tampered
    # checkpoint. Accepts either a raw state_dict or a full training
    # checkpoint wrapping it under "model_state_dict".
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)

    # Move model to CPU for ONNX export
    model.cpu()
    model.eval()

    # Create dummy input (batch_size=1 for tracing)
    dummy_input = torch.randn(1, *input_shape)

    logger.info("  [Export Settings]")
    logger.info("    %s Format            : ONNX (opset %s)", LogStyle.BULLET, opset_version)
    logger.info("    %s Input shape       : %s", LogStyle.BULLET, tuple(dummy_input.shape))
    logger.info("    %s Dynamic axes      : %s", LogStyle.BULLET, dynamic_axes)
    logger.info("")

    # Prepare dynamic axes configuration: mark dim 0 of both tensors as a
    # free "batch_size" dimension so inference is not pinned to batch=1.
    if dynamic_axes:
        dynamic_axes_config: dict[str, dict[int, str]] | None = {
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        }
    else:
        dynamic_axes_config = None

    # Export to ONNX (suppress verbose PyTorch warnings for cleaner output)
    # Temporarily suppress torch.onnx internal loggers to avoid roi_align warnings
    onnx_loggers = [
        logging.getLogger("torch.onnx._internal.exporter._schemas"),
        logging.getLogger("torch.onnx._internal.exporter"),
    ]
    original_levels = [log.level for log in onnx_loggers]

    try:
        # Raise log level to ERROR to suppress WARNING messages
        for onnx_logger in onnx_loggers:
            onnx_logger.setLevel(logging.ERROR)

        with (
            warnings.catch_warnings(),
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
        ):
            # Suppress warnings and stdout prints (e.g. ONNX rewrite rules)
            warnings.simplefilter("ignore")

            # "input"/"output" names are relied upon by the benchmark and
            # validation helpers in this package.
            torch.onnx.export(
                model,
                (dummy_input,),  # Wrap in tuple for mypy type checking
                str(output_path),
                export_params=True,
                opset_version=opset_version,
                do_constant_folding=do_constant_folding,
                input_names=["input"],
                output_names=["output"],
                dynamic_axes=dynamic_axes_config,
                verbose=False,
            )
    finally:
        # Restore original log levels
        for onnx_logger, original_level in zip(onnx_loggers, original_levels):
            onnx_logger.setLevel(original_level)

    # Validate exported model
    if validate:
        try:
            import onnx

            onnx_model = onnx.load(str(output_path))
            onnx.checker.check_model(onnx_model)

            logger.info("  [Validation]")
            logger.info("    %s ONNX check        : %s Valid", LogStyle.BULLET, LogStyle.SUCCESS)
            size_mb = _onnx_file_size_mb(output_path)
            logger.info("    %s Model size        : %.2f MB", LogStyle.BULLET, size_mb)

        except ImportError:
            logger.warning(
                "    %s onnx package not installed. Skipping validation.", LogStyle.WARNING
            )
        except (ValueError, RuntimeError) as e:
            logger.error("    %s ONNX validation failed: %s", LogStyle.FAILURE, e)
            # Never leave a known-bad artifact on disk for a deploy step to pick up.
            if output_path.exists():
                output_path.unlink()
                logger.info("    %s Cleaned up invalid ONNX file", LogStyle.ARROW)
            raise

    logger.info("")

quantize_model(onnx_path, output_path=None, backend='qnnpack', weight_type='int8')

Apply dynamic post-training quantization to an ONNX model.

Dispatches to 8-bit or 4-bit quantize_dynamic based on weight_type. INT4/UINT4 quantize only Gemm/MatMul nodes (Linear layers), leaving Conv layers at full precision.

Parameters:

Name Type Description Default
onnx_path Path

Path to the exported ONNX model

required
output_path Path | None

Path for quantized model (defaults to model_quantized.onnx in the same directory)

None
backend str

Quantization backend ("qnnpack" for mobile/ARM, "fbgemm" for x86)

'qnnpack'
weight_type str

Weight quantization type — "int8", "uint8", "int4", or "uint4"

'int8'

Returns:

Type Description
Path | None

Path to the quantized ONNX model, or None if quantization failed

Example

>>> quantized = quantize_model(Path("exports/model.onnx"))
>>> print(f"Quantized model: {quantized}")

Source code in orchard/export/onnx_exporter.py
def quantize_model(
    onnx_path: Path,
    output_path: Path | None = None,
    backend: str = "qnnpack",
    weight_type: str = "int8",
) -> Path | None:
    """
    Apply dynamic post-training quantization to an ONNX model.

    Delegates to the 4-bit or 8-bit ``quantize_dynamic`` helper depending
    on *weight_type*.  The INT4/UINT4 path quantizes only Gemm/MatMul nodes
    (Linear layers); Conv layers remain at full precision.

    Args:
        onnx_path: Path to the exported ONNX model
        output_path: Path for quantized model (defaults to model_quantized.onnx
                     in the same directory)
        backend: Quantization backend ("qnnpack" for mobile/ARM, "fbgemm" for x86)
        weight_type: Weight quantization type — "int8", "uint8", "int4", or "uint4"

    Returns:
        Path to the quantized ONNX model, or None if quantization failed

    Example:
        >>> quantized = quantize_model(Path("exports/model.onnx"))
        >>> print(f"Quantized model: {quantized}")
    """
    # Default destination sits next to the source model.
    target = onnx_path.parent / "model_quantized.onnx" if output_path is None else output_path

    logger.info("  [Quantization]")
    logger.info("    %s Backend           : %s", LogStyle.BULLET, backend)
    logger.info("    %s Weight type       : %s", LogStyle.BULLET, weight_type)

    try:
        if weight_type not in ("int4", "uint4"):
            _quantize_8bit(onnx_path, target, backend, weight_type)
        else:
            _quantize_4bit(onnx_path, target, weight_type)
    except ImportError:
        logger.warning(
            "    %s onnxruntime.quantization not available. Skipping quantization.",
            LogStyle.WARNING,
        )
        return None
    except Exception as e:  # onnxruntime raises non-standard exceptions
        logger.error("    %s Quantization failed: %s", LogStyle.FAILURE, e)
        # Drop any partially written artifact so callers never see a bad file.
        if target.exists():
            target.unlink()
        return None

    # Report the size reduction achieved by quantization.
    size_before = _onnx_file_size_mb(onnx_path)
    size_after = _onnx_file_size_mb(target)
    shrink_factor = size_before / size_after if size_after > 0 else 0

    logger.info(
        "    %s Size              : %.2f MB → %.2f MB (%.1fx)",
        LogStyle.BULLET,
        size_before,
        size_after,
        shrink_factor,
    )
    logger.info("    %s Status            : %s Done", LogStyle.BULLET, LogStyle.SUCCESS)
    logger.info("")

    return target

validate_export(pytorch_model, onnx_path, input_shape, num_samples=10, max_deviation=0.0001, label='ONNX')

Validate ONNX export against PyTorch model.

Compares outputs from PyTorch and ONNX models on random inputs to ensure numerical consistency after export.

Parameters:

Name Type Description Default
pytorch_model Module

Original PyTorch model (with loaded weights)

required
onnx_path Path

Path to exported ONNX model

required
input_shape tuple[int, int, int]

Input tensor shape (C, H, W)

required
num_samples int

Number of random samples to test

10
max_deviation float

Maximum allowed absolute difference

0.0001
label str

Display label for log header (e.g. "ONNX", "Quantized")

'ONNX'

Returns:

Type Description
bool | None

True if validation passes, False if outputs diverge,

bool | None

None if skipped (onnxruntime not installed).

Raises:

Type Description
OrchardExportError

If the ONNX file does not exist.

Example

>>> model.load_state_dict(torch.load("checkpoint.pth"))
>>> valid = validate_export(model, Path("model.onnx"), input_shape=(3, 224, 224))
>>> if valid:
...     print("Export validated successfully!")

Source code in orchard/export/validation.py
def validate_export(
    pytorch_model: nn.Module,
    onnx_path: Path,
    input_shape: tuple[int, int, int],
    num_samples: int = 10,
    max_deviation: float = 1e-4,
    label: str = "ONNX",
) -> bool | None:
    """
    Validate ONNX export against PyTorch model.

    Runs both the PyTorch model and the exported ONNX session on identical
    deterministic random inputs and confirms the outputs agree within
    *max_deviation* absolute difference.

    Args:
        pytorch_model: Original PyTorch model (with loaded weights)
        onnx_path: Path to exported ONNX model
        input_shape: Input tensor shape (C, H, W)
        num_samples: Number of random samples to test
        max_deviation: Maximum allowed absolute difference
        label: Display label for log header (e.g. "ONNX", "Quantized")

    Returns:
        True if validation passes, False if outputs diverge,
        None if skipped (onnxruntime not installed).

    Raises:
        OrchardExportError: If the ONNX file does not exist.

    Example:
        >>> model.load_state_dict(torch.load("checkpoint.pth"))
        >>> valid = validate_export(model, Path("model.onnx"), input_shape=(3, 224, 224))
        >>> if valid:
        ...     print("Export validated successfully!")
    """
    # Fail fast on a missing artifact before touching onnxruntime.
    if not onnx_path.exists():
        raise OrchardExportError(f"ONNX model not found: {onnx_path}")

    try:
        import onnxruntime as ort

        logger.info("%s[%s Validation]", LogStyle.INDENT, label)
        logger.info("    %s Samples           : %s", LogStyle.BULLET, num_samples)
        logger.info("    %s Max deviation     : %.0e", LogStyle.BULLET, max_deviation)

        # Pin the session to CPU so both sides run under export conditions.
        session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])

        pytorch_model.eval()
        pytorch_model.cpu()

        # Seeded generator => identical inputs on every invocation.
        generator = torch.Generator(device="cpu")
        generator.manual_seed(0)
        worst_diff = 0.0

        with torch.no_grad():
            for sample_idx in range(num_samples):
                sample = torch.randn(1, *input_shape, generator=generator)
                sample_np = sample.numpy().astype(np.float32)

                # Run the same input through both backends.
                torch_out = pytorch_model(sample).numpy()
                onnx_out = session.run(None, {"input": sample_np})[0]

                # A shape mismatch means the export itself is broken.
                if torch_out.shape != onnx_out.shape:
                    raise OrchardExportError(
                        f"Output shape mismatch: PyTorch {torch_out.shape} vs ONNX {onnx_out.shape}"
                    )

                sample_diff = np.abs(torch_out - onnx_out).max()
                worst_diff = max(worst_diff, sample_diff)

                if sample_diff > max_deviation:
                    logger.error(
                        "    %s Result            : %s FAILED sample %d (diff: %.2e, threshold: %.2e)",
                        LogStyle.BULLET,
                        LogStyle.WARNING,
                        sample_idx + 1,
                        sample_diff,
                        max_deviation,
                    )
                    return False

        logger.info(
            "    %s Result            : %s Passed (max diff: %.2e)",
            LogStyle.BULLET,
            LogStyle.SUCCESS,
            worst_diff,
        )
        logger.info("")
        return True

    except ImportError as e:
        logger.warning("onnxruntime not installed. Skipping validation: %s", e)
        return None
    except (RuntimeError, ValueError) as e:
        logger.error("Validation failed: %s", e, exc_info=True)
        raise