Skip to content

validation

orchard.export.validation

Export Numerical Validation.

Verifies that the ONNX graph produces outputs numerically equivalent to the original PyTorch model. validate_export runs num_samples seeded random forward passes through both runtimes (PyTorch on CPU, onnxruntime CPUExecutionProvider) and checks that the maximum absolute deviation does not exceed a configurable threshold (default 1e-4), returning False on the first sample that does. Called automatically by the export pipeline when validate=True.

validate_export(pytorch_model, onnx_path, input_shape, num_samples=10, max_deviation=0.0001, label='ONNX')

Validate ONNX export against PyTorch model.

Compares outputs from PyTorch and ONNX models on random inputs to ensure numerical consistency after export.

Parameters:

Name Type Description Default
pytorch_model Module

Original PyTorch model (with loaded weights)

required
onnx_path Path

Path to exported ONNX model

required
input_shape tuple[int, int, int]

Input tensor shape (C, H, W)

required
num_samples int

Number of random samples to test

10
max_deviation float

Maximum allowed absolute difference

0.0001
label str

Display label for log header (e.g. "ONNX", "Quantized")

'ONNX'

Returns:

Type Description
bool | None

True if validation passes, False if outputs diverge,

bool | None

None if skipped (onnxruntime not installed).

Raises:

Type Description
OrchardExportError

If the ONNX file does not exist, or if the PyTorch and ONNX output shapes differ.

Example

>>> model.load_state_dict(torch.load("checkpoint.pth"))
>>> valid = validate_export(model, Path("model.onnx"), input_shape=(3, 224, 224))
>>> if valid:
...     print("Export validated successfully!")

Source code in orchard/export/validation.py
def validate_export(
    pytorch_model: nn.Module,
    onnx_path: Path,
    input_shape: tuple[int, int, int],
    num_samples: int = 10,
    max_deviation: float = 1e-4,
    label: str = "ONNX",
) -> bool | None:
    """
    Validate ONNX export against PyTorch model.

    Compares outputs from PyTorch and ONNX models on random inputs
    to ensure numerical consistency after export. Both runtimes run on
    CPU, and inputs come from a generator seeded with 0, so repeated
    calls test the same inputs.

    Side effects: ``pytorch_model`` is switched to eval mode and moved to
    CPU in place; neither change is reverted.

    Args:
        pytorch_model: Original PyTorch model (with loaded weights)
        onnx_path: Path to exported ONNX model
        input_shape: Input tensor shape (C, H, W); a batch dimension of 1
            is prepended for every sample.
        num_samples: Number of random samples to test. If this is 0 or
            negative, nothing is compared and validation reports a pass.
        max_deviation: Maximum allowed absolute difference. A non-finite
            (NaN/inf) deviation always counts as a failure.
        label: Display label for log header (e.g. "ONNX", "Quantized")

    Returns:
        True if validation passes, False if outputs diverge,
        None if skipped (onnxruntime not installed).

    Raises:
        OrchardExportError: If the ONNX file does not exist, or if the
            PyTorch and ONNX output shapes differ.

    Example:
        >>> model.load_state_dict(torch.load("checkpoint.pth"))
        >>> valid = validate_export(model, Path("model.onnx"), input_shape=(3, 224, 224))
        >>> if valid:
        ...     print("Export validated successfully!")
    """
    # Check if ONNX file exists
    if not onnx_path.exists():
        raise OrchardExportError(f"ONNX model not found: {onnx_path}")

    # Guard only the optional import. An ImportError raised later (e.g. from
    # inside the model's forward) must propagate rather than be reported as
    # a missing onnxruntime install.
    try:
        import onnxruntime as ort
    except ImportError as e:
        logger.warning("onnxruntime not installed. Skipping validation: %s", e)
        return None

    try:
        logger.info("%s[%s Validation]", LogStyle.INDENT, label)
        logger.info("    %s Samples           : %s", LogStyle.BULLET, num_samples)
        logger.info("    %s Max deviation     : %.0e", LogStyle.BULLET, max_deviation)

        # Load ONNX model (force CPU to match export conditions)
        session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
        # Feed the graph's declared first input instead of assuming it is
        # named "input"; identical behavior when it is.
        input_name = session.get_inputs()[0].name

        pytorch_model.eval()
        pytorch_model.cpu()

        max_diff = 0.0
        g = torch.Generator(device="cpu")
        g.manual_seed(0)

        with torch.no_grad():
            for i in range(num_samples):
                # Generate deterministic random input for reproducible validation
                x_torch = torch.randn(1, *input_shape, generator=g)
                x_numpy = x_torch.numpy().astype(np.float32)

                # PyTorch inference
                y_torch = pytorch_model(x_torch).numpy()

                # ONNX inference
                y_onnx = session.run(None, {input_name: x_numpy})[0]

                # Shape guard: detect export shape mismatches early
                if y_torch.shape != y_onnx.shape:
                    raise OrchardExportError(
                        f"Output shape mismatch: PyTorch {y_torch.shape} vs ONNX {y_onnx.shape}"
                    )

                # Compare outputs. An empty output has nothing to diverge,
                # and .max() on a zero-size array would raise.
                if y_torch.size == 0:
                    diff = 0.0
                else:
                    diff = float(np.abs(y_torch - y_onnx).max())
                max_diff = max(max_diff, diff)

                # NaN compares False against any threshold, so a plain
                # `diff > max_deviation` would let NaN outputs pass.
                if not np.isfinite(diff) or diff > max_deviation:
                    logger.error(
                        "    %s Result            : %s FAILED sample %d (diff: %.2e, threshold: %.2e)",
                        LogStyle.BULLET,
                        LogStyle.WARNING,
                        i + 1,
                        diff,
                        max_deviation,
                    )
                    return False

        logger.info(
            "    %s Result            : %s Passed (max diff: %.2e)",
            LogStyle.BULLET,
            LogStyle.SUCCESS,
            max_diff,
        )
        logger.info("")
        return True

    except (RuntimeError, ValueError) as e:
        logger.error("Validation failed: %s", e, exc_info=True)
        raise