Skip to content

onnx_exporter

orchard.export.onnx_exporter

ONNX Export, Quantization, and Benchmarking.

End-to-end production pipeline for converting trained PyTorch checkpoints to optimized ONNX graphs. The module is consumed by the CLI orchard export command and operates entirely on CPU.

Key Functions:

  • export_to_onnx: Trace-based export with dynamic batch axes, constant folding, and optional onnx.checker validation.
  • quantize_model: Dynamic post-training quantization (INT8, UINT8, INT4, UINT4) via onnxruntime (qnnpack for ARM, fbgemm for x86).
  • benchmark_onnx_inference: Warm-up + timed inference loop returning average latency in milliseconds.

export_to_onnx(model, checkpoint_path, output_path, input_shape, opset_version=18, dynamic_axes=True, do_constant_folding=True, validate=True)

Export trained PyTorch model to ONNX format.

Parameters:

Name Type Description Default
model Module

PyTorch model architecture (uninitialized weights OK)

required
checkpoint_path Path

Path to trained .pth checkpoint

required
output_path Path

Output path for .onnx file

required
input_shape tuple[int, int, int]

Input tensor shape (C, H, W)

required
opset_version int

ONNX opset version (default: 18)

18
dynamic_axes bool

Enable dynamic batch size (required for production)

True
do_constant_folding bool

Optimize constant operations at export

True
validate bool

Validate exported model with ONNX checker

True

Raises:

Type Description
FileNotFoundError

If checkpoint_path does not exist.

RuntimeError

If state_dict loading fails (architecture mismatch).

ValueError

If ONNX validation fails (when validate=True).

Example

>>> export_to_onnx(
...     model=EfficientNet(),
...     checkpoint_path=Path("outputs/best_model.pth"),
...     output_path=Path("exports/model.onnx"),
...     input_shape=(3, 224, 224),
... )

Source code in orchard/export/onnx_exporter.py
def export_to_onnx(
    model: nn.Module,
    checkpoint_path: Path,
    output_path: Path,
    input_shape: tuple[int, int, int],
    opset_version: int = 18,
    dynamic_axes: bool = True,
    do_constant_folding: bool = True,
    validate: bool = True,
) -> None:
    """
    Export a trained PyTorch model to ONNX.

    Loads weights from *checkpoint_path* into *model*, traces it on CPU with a
    single dummy batch, and writes the resulting graph to *output_path*.

    Args:
        model: PyTorch model architecture (uninitialized weights OK)
        checkpoint_path: Path to trained .pth checkpoint
        output_path: Output path for .onnx file
        input_shape: Input tensor shape (C, H, W)
        opset_version: ONNX opset version (default: 18)
        dynamic_axes: Enable dynamic batch size (required for production)
        do_constant_folding: Optimize constant operations at export
        validate: Validate exported model with ONNX checker

    Raises:
        FileNotFoundError: If checkpoint_path does not exist.
        RuntimeError: If state_dict loading fails (architecture mismatch).
        ValueError: If ONNX validation fails (when validate=True).

    Example:
        >>> export_to_onnx(
        ...     model=EfficientNet(),
        ...     checkpoint_path=Path("outputs/best_model.pth"),
        ...     output_path=Path("exports/model.onnx"),
        ...     input_shape=(3, 224, 224),
        ... )
    """
    logger.info("  [Source]")
    logger.info("    %s Checkpoint        : %s", LogStyle.BULLET, checkpoint_path.name)
    logger.info("")

    # Make sure the destination directory exists before writing anything.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Restore trained weights; accepts both full training checkpoints
    # (with a "model_state_dict" entry) and raw state_dicts.
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    if isinstance(state, dict) and "model_state_dict" in state:
        model.load_state_dict(state["model_state_dict"])
    else:
        model.load_state_dict(state)

    # Export happens on CPU with the model frozen in eval mode.
    model.cpu()
    model.eval()

    # A single-sample tensor drives the trace (batch_size=1).
    sample = torch.randn(1, *input_shape)

    logger.info("  [Export Settings]")
    logger.info("    %s Format            : ONNX (opset %s)", LogStyle.BULLET, opset_version)
    logger.info("    %s Input shape       : %s", LogStyle.BULLET, tuple(sample.shape))
    logger.info("    %s Dynamic axes      : %s", LogStyle.BULLET, dynamic_axes)
    logger.info("")

    # Mark the batch dimension of both graph endpoints as dynamic if requested.
    axes_cfg = (
        {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
        if dynamic_axes
        else None
    )

    # torch.onnx internals emit noisy warnings (e.g. roi_align schemas);
    # raise their log level for the duration of the export only.
    suppressed = [
        logging.getLogger("torch.onnx._internal.exporter._schemas"),
        logging.getLogger("torch.onnx._internal.exporter"),
    ]
    saved_levels = [lg.level for lg in suppressed]

    try:
        for lg in suppressed:
            lg.setLevel(logging.ERROR)

        with (
            warnings.catch_warnings(),
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
        ):
            # Also mute Python warnings and stray prints (ONNX rewrite rules).
            warnings.simplefilter("ignore")

            torch.onnx.export(
                model,
                (sample,),  # tuple form keeps mypy happy
                str(output_path),
                export_params=True,
                opset_version=opset_version,
                do_constant_folding=do_constant_folding,
                input_names=["input"],
                output_names=["output"],
                dynamic_axes=axes_cfg,
                verbose=False,
            )
    finally:
        # Always restore the original logger levels.
        for lg, lvl in zip(suppressed, saved_levels):
            lg.setLevel(lvl)

    # Optionally run the ONNX structural checker on the written file.
    if validate:
        try:
            import onnx

            onnx.checker.check_model(onnx.load(str(output_path)))

            logger.info("  [Validation]")
            logger.info("    %s ONNX check        : %s Valid", LogStyle.BULLET, LogStyle.SUCCESS)
            file_mb = _onnx_file_size_mb(output_path)
            logger.info("    %s Model size        : %.2f MB", LogStyle.BULLET, file_mb)

        except ImportError:
            logger.warning(
                "    %s onnx package not installed. Skipping validation.", LogStyle.WARNING
            )
        except (ValueError, RuntimeError) as e:
            logger.error("    %s ONNX validation failed: %s", LogStyle.FAILURE, e)
            # Remove the invalid artifact so callers never pick it up.
            if output_path.exists():
                output_path.unlink()
                logger.info("    %s Cleaned up invalid ONNX file", LogStyle.ARROW)
            raise

    logger.info("")

quantize_model(onnx_path, output_path=None, backend='qnnpack', weight_type='int8')

Apply dynamic post-training quantization to an ONNX model.

Dispatches to 8-bit or 4-bit quantize_dynamic based on weight_type. INT4/UINT4 quantize only Gemm/MatMul nodes (Linear layers), leaving Conv layers at full precision.

Parameters:

Name Type Description Default
onnx_path Path

Path to the exported ONNX model

required
output_path Path | None

Path for quantized model (defaults to model_quantized.onnx in the same directory)

None
backend str

Quantization backend ("qnnpack" for mobile/ARM, "fbgemm" for x86)

'qnnpack'
weight_type str

Weight quantization type — "int8", "uint8", "int4", or "uint4"

'int8'

Returns:

Type Description
Path | None

Path to the quantized ONNX model, or None if quantization failed

Example

>>> quantized = quantize_model(Path("exports/model.onnx"))
>>> print(f"Quantized model: {quantized}")

Source code in orchard/export/onnx_exporter.py
def quantize_model(
    onnx_path: Path,
    output_path: Path | None = None,
    backend: str = "qnnpack",
    weight_type: str = "int8",
) -> Path | None:
    """
    Apply dynamic post-training quantization to an ONNX model.

    Routes to the 8-bit or 4-bit ``quantize_dynamic`` helper depending on
    *weight_type*.  The 4-bit path (INT4/UINT4) only quantizes Gemm/MatMul
    nodes (Linear layers); Conv layers stay at full precision.

    Args:
        onnx_path: Path to the exported ONNX model
        output_path: Path for quantized model (defaults to model_quantized.onnx
                     in the same directory)
        backend: Quantization backend ("qnnpack" for mobile/ARM, "fbgemm" for x86)
        weight_type: Weight quantization type — "int8", "uint8", "int4", or "uint4"

    Returns:
        Path to the quantized ONNX model, or None if quantization failed

    Example:
        >>> quantized = quantize_model(Path("exports/model.onnx"))
        >>> print(f"Quantized model: {quantized}")
    """
    # Default destination: a sibling file next to the source model.
    output_path = (
        onnx_path.parent / "model_quantized.onnx" if output_path is None else output_path
    )

    logger.info("  [Quantization]")
    logger.info("    %s Backend           : %s", LogStyle.BULLET, backend)
    logger.info("    %s Weight type       : %s", LogStyle.BULLET, weight_type)

    try:
        # 4-bit types take a dedicated path; everything else goes through 8-bit.
        if weight_type in ("int4", "uint4"):
            _quantize_4bit(onnx_path, output_path, weight_type)
        else:
            _quantize_8bit(onnx_path, output_path, backend, weight_type)
    except ImportError:
        logger.warning(
            "    %s onnxruntime.quantization not available. Skipping quantization.",
            LogStyle.WARNING,
        )
        return None
    except Exception as e:  # onnxruntime raises non-standard exceptions
        logger.error("    %s Quantization failed: %s", LogStyle.FAILURE, e)
        # Drop any partially-written artifact.
        output_path.unlink(missing_ok=True)
        return None

    # Report the size reduction achieved by quantization.
    src_mb = _onnx_file_size_mb(onnx_path)
    dst_mb = _onnx_file_size_mb(output_path)
    factor = src_mb / dst_mb if dst_mb > 0 else 0

    logger.info(
        "    %s Size              : %.2f MB → %.2f MB (%.1fx)",
        LogStyle.BULLET,
        src_mb,
        dst_mb,
        factor,
    )
    logger.info("    %s Status            : %s Done", LogStyle.BULLET, LogStyle.SUCCESS)
    logger.info("")

    return output_path

benchmark_onnx_inference(onnx_path, input_shape, num_runs=100, seed=42, label='ONNX')

Benchmark ONNX model inference speed.

Parameters:

Name Type Description Default
onnx_path Path

Path to ONNX model

required
input_shape tuple[int, int, int]

Input tensor shape (C, H, W)

required
num_runs int

Number of inference runs for averaging

100
seed int

Random seed for reproducible dummy input

42
label str

Display label for the benchmark log header

'ONNX'

Returns:

Type Description
float

Average inference time in milliseconds

Example

>>> latency = benchmark_onnx_inference(Path("model.onnx"))
>>> print(f"Latency: {latency:.2f}ms")

Source code in orchard/export/onnx_exporter.py
def benchmark_onnx_inference(
    onnx_path: Path,
    input_shape: tuple[int, int, int],
    num_runs: int = 100,
    seed: int = 42,
    label: str = "ONNX",
    warmup: int = 10,
) -> float:
    """
    Benchmark ONNX model inference speed.

    Args:
        onnx_path: Path to ONNX model
        input_shape: Input tensor shape (C, H, W)
        num_runs: Number of inference runs for averaging
        seed: Random seed for reproducible dummy input
        label: Display label for the benchmark log header
        warmup: Number of untimed runs before measurement (default: 10)

    Returns:
        Average inference time in milliseconds, or -1.0 if the benchmark
        could not run (onnxruntime missing or session failure)

    Example:
        >>> latency = benchmark_onnx_inference(Path("model.onnx"))
        >>> print(f"Latency: {latency:.2f}ms")
    """
    try:
        import time

        import numpy as np
        import onnxruntime as ort

        logger.info("  [Benchmark — %s]", label)

        # Create inference session
        session = ort.InferenceSession(str(onnx_path))

        # Dummy input drawn from N(0,1) (matches validation distribution)
        rng = np.random.default_rng(seed)
        dummy_input = rng.standard_normal(size=(1, *input_shape)).astype(np.float32)

        # Warmup: untimed runs let the runtime finish lazy initialization
        # and kernel selection before measurement starts
        for _ in range(warmup):
            session.run(None, {"input": dummy_input})

        # Benchmark with time.perf_counter(): monotonic and high-resolution,
        # unlike time.time() which is wall-clock and can jump (e.g. NTP sync),
        # skewing the measured interval
        start = time.perf_counter()
        for _ in range(num_runs):
            session.run(None, {"input": dummy_input})
        elapsed = time.perf_counter() - start

        avg_latency_ms = (elapsed / num_runs) * 1000
        logger.info("    %s Runs              : %s", LogStyle.BULLET, num_runs)
        logger.info("    %s Avg latency       : %.2fms", LogStyle.BULLET, avg_latency_ms)
        logger.info("")

        return avg_latency_ms

    except ImportError:
        logger.warning("onnxruntime not installed. Skipping benchmark.")
        return -1.0
    except Exception as e:  # onnxruntime raises non-standard exceptions
        logger.error("Benchmark failed: %s", e)
        return -1.0