# `orchard.export.onnx_exporter`
ONNX Export, Quantization, and Benchmarking.
End-to-end production pipeline for converting trained PyTorch checkpoints
to optimized ONNX graphs. The module is consumed by the CLI `orchard export`
command and operates entirely on CPU.

Key Functions:

- `export_to_onnx`: Trace-based export with dynamic batch axes, constant folding, and optional `onnx.checker` validation.
- `quantize_model`: Dynamic post-training quantization (INT8, UINT8, INT4, UINT4) via onnxruntime (qnnpack for ARM, fbgemm for x86).
- `benchmark_onnx_inference`: Warm-up + timed inference loop returning average latency in milliseconds.
## `export_to_onnx(model, checkpoint_path, output_path, input_shape, opset_version=18, dynamic_axes=True, do_constant_folding=True, validate=True)`
Export trained PyTorch model to ONNX format.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | PyTorch model architecture (uninitialized weights OK) | *required* |
| `checkpoint_path` | `Path` | Path to trained `.pth` checkpoint | *required* |
| `output_path` | `Path` | Output path for `.onnx` file | *required* |
| `input_shape` | `tuple[int, int, int]` | Input tensor shape `(C, H, W)` | *required* |
| `opset_version` | `int` | ONNX opset version | `18` |
| `dynamic_axes` | `bool` | Enable dynamic batch size (required for production) | `True` |
| `do_constant_folding` | `bool` | Optimize constant operations at export | `True` |
| `validate` | `bool` | Validate exported model with ONNX checker | `True` |
Raises:

| Type | Description |
|---|---|
| `FileNotFoundError` | If `checkpoint_path` does not exist. |
| `RuntimeError` | If `state_dict` loading fails (architecture mismatch). |
| `ValueError` | If ONNX validation fails (when `validate=True`). |
Example

```python
>>> export_to_onnx(
...     model=EfficientNet(),
...     checkpoint_path=Path("outputs/best_model.pth"),
...     output_path=Path("exports/model.onnx"),
...     input_shape=(3, 224, 224),
... )
```
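As a rough illustration, the `dynamic_axes=True` flag presumably translates into a mapping like the one below when the module calls `torch.onnx.export`; the tensor names `"input"` and `"output"` are assumptions for this sketch, not confirmed by this page:

```python
# Hypothetical sketch of the dynamic-axes mapping export_to_onnx might build.
input_shape = (3, 224, 224)        # (C, H, W) as documented above
dummy_shape = (1, *input_shape)    # prepend a batch dim for the tracing input

# Marking axis 0 symbolic keeps the batch size dynamic in the exported graph.
dynamic_axes = {
    "input": {0: "batch_size"},
    "output": {0: "batch_size"},
}
```

A fixed batch axis would otherwise be baked into the graph at trace time, which is why the docs call dynamic batching "required for production".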
Source code in orchard/export/onnx_exporter.py
## `quantize_model(onnx_path, output_path=None, backend='qnnpack', weight_type='int8')`
Apply dynamic post-training quantization to an ONNX model.

Dispatches to the 8-bit or 4-bit `quantize_dynamic` path based on
`weight_type`. INT4/UINT4 quantize only Gemm/MatMul nodes (Linear
layers), leaving Conv layers at full precision.
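The Gemm/MatMul restriction above amounts to a simple op-type filter; the helper below is an illustrative sketch, not the module's actual code:

```python
# Op types eligible for 4-bit quantization per the description above:
# only Linear-style nodes; Conv stays at full precision.
QUANTIZABLE_4BIT_OPS = {"Gemm", "MatMul"}

def nodes_to_quantize(op_types: list[str]) -> list[str]:
    """Keep only op types that the 4-bit path would quantize."""
    return [op for op in op_types if op in QUANTIZABLE_4BIT_OPS]
```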
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `onnx_path` | `Path` | Path to the exported ONNX model | *required* |
| `output_path` | `Path \| None` | Path for quantized model (defaults to `model_quantized.onnx` in the same directory) | `None` |
| `backend` | `str` | Quantization backend (`"qnnpack"` for mobile/ARM, `"fbgemm"` for x86) | `'qnnpack'` |
| `weight_type` | `str` | Weight quantization type: `"int8"`, `"uint8"`, `"int4"`, or `"uint4"` | `'int8'` |
Returns:

| Type | Description |
|---|---|
| `Path \| None` | Path to the quantized ONNX model, or `None` if quantization failed |
Example

```python
>>> quantized = quantize_model(Path("exports/model.onnx"))
>>> print(f"Quantized model: {quantized}")
```
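A minimal sketch of how the `weight_type` string might be dispatched to an onnxruntime `QuantType` name; the mapping and helper are hypothetical, and the 4-bit enum names in particular should be checked against your onnxruntime version:

```python
# Assumed mapping from the documented weight_type strings to
# onnxruntime.quantization.QuantType member names (not verified here).
WEIGHT_TYPES = {
    "int8": "QInt8",
    "uint8": "QUInt8",
    "int4": "QInt4",
    "uint4": "QUInt4",
}

def resolve_weight_type(weight_type: str) -> str:
    """Map a weight_type string to a QuantType name, case-insensitively."""
    try:
        return WEIGHT_TYPES[weight_type.lower()]
    except KeyError:
        raise ValueError(f"Unsupported weight_type: {weight_type!r}")
```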
Source code in orchard/export/onnx_exporter.py
## `benchmark_onnx_inference(onnx_path, input_shape, num_runs=100, seed=42, label='ONNX')`
Benchmark ONNX model inference speed.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `onnx_path` | `Path` | Path to ONNX model | *required* |
| `input_shape` | `tuple[int, int, int]` | Input tensor shape `(C, H, W)` | *required* |
| `num_runs` | `int` | Number of inference runs for averaging | `100` |
| `seed` | `int` | Random seed for reproducible dummy input | `42` |
| `label` | `str` | Display label for the benchmark log header | `'ONNX'` |
Returns:

| Type | Description |
|---|---|
| `float` | Average inference time in milliseconds |
Example

```python
>>> latency = benchmark_onnx_inference(Path("model.onnx"), input_shape=(3, 224, 224))
>>> print(f"Latency: {latency:.2f}ms")
```