Skip to content

dataset

orchard.data_handler.dataset

PyTorch Dataset Definition Module.

This module contains the custom Dataset class for NPZ-based vision datasets, handling the conversion from NumPy arrays to PyTorch tensors and applying image transformations for training and inference.

It supports two loading strategies via classmethod factories:

  • from_npz: Eager loading into RAM with transforms, subsampling, and PIL conversion.
  • lazy: Memory-mapped loading for large datasets or lightweight health checks.

Key Components:

  • VisionDataset: Full-featured dataset with eager and lazy loading modes.

VisionDataset(images, labels, *, transform=None)

Bases: Dataset[tuple[Tensor, Tensor]]

PyTorch Dataset for NPZ-based vision data.

The constructor accepts raw NumPy arrays directly (no I/O). Use the classmethod factories to load from disk:

  • VisionDataset.from_npz(...) — eager, full split into RAM.
  • VisionDataset.lazy(...) — memory-mapped, pages loaded on demand.

Initializes the dataset from pre-loaded arrays.

Parameters:

  • images (NDArray[Any]) — Image array with shape (N, H, W) or (N, H, W, C). Required.
  • labels (NDArray[Any]) — Label array, any shape that flattens to (N,). Required.
  • transform (Compose | None) — Pipeline of Torchvision transforms. Default: None.

Source code in orchard/data_handler/dataset.py
Source code in orchard/data_handler/dataset.py
def __init__(
    self,
    images: npt.NDArray[Any],
    labels: npt.NDArray[Any],
    *,
    transform: transforms.Compose | None = None,
) -> None:
    """
    Builds the dataset from arrays already resident in memory.

    Args:
        images: Image array shaped ``(N, H, W)`` or ``(N, H, W, C)``.
        labels: Label array; any shape that flattens to ``(N,)``.
        transform: Pipeline of Torchvision transforms.
    """
    # Guarantee a trailing channel axis so downstream PIL conversion
    # always sees (N, H, W, C).
    if images.ndim == 3:
        images = images[..., np.newaxis]

    self.transform = transform
    self.images = images
    self.labels: npt.NDArray[Any] = labels.ravel().astype(np.int64)

    # Set by .lazy(): keeps the NPZ archive handle alive so its arrays
    # are not garbage-collected while we still reference them.
    self._npz_handle: np.lib.npyio.NpzFile | None = None
    # Lazy-subsampling index map; None means "use every sample".
    self._indices: npt.NDArray[Any] | None = None

from_npz(path, split='train', *, transform=None, max_samples=None, seed=42) classmethod

Eagerly load a split from an NPZ archive into RAM.

Parameters:

  • path (Path) — Path to the dataset .npz archive. Required.
  • split (str) — Dataset split to load (train, val, or test). Default: 'train'.
  • transform (Compose | None) — Pipeline of Torchvision transforms. Default: None.
  • max_samples (int | None) — If set, limits the number of samples (subsampling). Default: None.
  • seed (int) — Random seed for deterministic subsampling. Default: 42.

Source code in orchard/data_handler/dataset.py
Source code in orchard/data_handler/dataset.py
@classmethod
def from_npz(
    cls,
    path: Path,
    split: str = "train",
    *,
    transform: transforms.Compose | None = None,
    max_samples: int | None = None,
    seed: int = 42,
) -> VisionDataset:
    """
    Eagerly load a split from an NPZ archive into RAM.

    Args:
        path: Path to the dataset ``.npz`` archive.
        split: Dataset split to load (``train``, ``val``, or ``test``).
        transform: Pipeline of Torchvision transforms.
        max_samples: If set, limits the number of samples (subsampling).
        seed: Random seed for deterministic subsampling.

    Returns:
        A dataset holding the (optionally subsampled) split.

    Raises:
        OrchardDatasetError: If the file or the requested split is missing.
    """
    if not path.exists():
        raise OrchardDatasetError(f"Dataset file not found at: {path}")

    with np.load(path) as data:
        # Fail fast with the module's domain error instead of a bare
        # KeyError when the archive lacks the requested split.
        if f"{split}_images" not in data.files or f"{split}_labels" not in data.files:
            raise OrchardDatasetError(
                f"Split '{split}' not found in dataset archive: {path}"
            )

        raw_images = data[f"{split}_images"]
        raw_labels = data[f"{split}_labels"]

        total_available = len(raw_labels)

        # Deterministic subsampling logic
        if max_samples and max_samples < total_available:
            rng = np.random.default_rng(seed)
            chosen = rng.choice(total_available, size=max_samples, replace=False)
            images = raw_images[chosen]
            labels = raw_labels[chosen]
        else:
            # NpzFile.__getitem__ returns fully materialized arrays that
            # outlive the archive handle, so the defensive full copy the
            # original made via ``np.array(raw_images)`` is unnecessary.
            images = raw_images
            labels = raw_labels

    return cls(images, labels, transform=transform)

lazy(path, split='train', *, transform=None, max_samples=None, seed=42) classmethod

Memory-mapped load from an NPZ archive (no full RAM copy).

Images are loaded page-by-page on demand. Suitable for large datasets that do not fit in RAM and for lightweight health checks.

Parameters:

  • path (Path) — Path to the .npz file. Required.
  • split (str) — Dataset split to load. Default: 'train'.
  • transform (Compose | None) — Pipeline of Torchvision transforms. Default: None.
  • max_samples (int | None) — If set, limits the number of samples (subsampling). Default: None.
  • seed (int) — Random seed for deterministic subsampling. Default: 42.

Source code in orchard/data_handler/dataset.py
Source code in orchard/data_handler/dataset.py
@classmethod
def lazy(
    cls,
    path: Path,
    split: str = "train",
    *,
    transform: transforms.Compose | None = None,
    max_samples: int | None = None,
    seed: int = 42,
) -> VisionDataset:
    """
    Load a split from an NPZ archive without an eager subsample copy.

    Suitable for large datasets and lightweight health checks.

    NOTE(review): ``np.load`` silently ignores ``mmap_mode`` for ``.npz``
    archives, so each accessed array is still fully materialized in RAM on
    first access (lazily per array, not paged) — confirm against the NumPy
    docs if true memory-mapping is required.

    Args:
        path: Path to the ``.npz`` file.
        split: Dataset split to load (default ``train``).
        transform: Pipeline of Torchvision transforms.
        max_samples: If set, limits the number of samples (subsampling).
        seed: Random seed for deterministic subsampling.

    Raises:
        OrchardDatasetError: If the dataset file does not exist.
    """
    # Mirror from_npz's explicit existence check so both factories raise
    # the same domain error rather than a bare FileNotFoundError.
    if not path.exists():
        raise OrchardDatasetError(f"Dataset file not found at: {path}")

    data = np.load(path, mmap_mode="r")
    instance = cls(data[f"{split}_images"], data[f"{split}_labels"], transform=transform)
    # Keep the archive handle alive for the dataset's lifetime.
    instance._npz_handle = data

    if max_samples and max_samples < len(instance.labels):
        rng = np.random.default_rng(seed)
        instance._indices = rng.choice(len(instance.labels), size=max_samples, replace=False)
        # Eagerly subsample labels (small) so .labels and __len__ stay consistent
        instance.labels = instance.labels[instance._indices]

    return instance

__len__()

Returns the total number of samples currently in the dataset.

Source code in orchard/data_handler/dataset.py
def __len__(self) -> int:
    """Number of samples currently exposed by the dataset."""
    # ``labels`` is always a flat (N,) array, so its leading dim is the count.
    return int(self.labels.shape[0])

__getitem__(idx)

Retrieves a standardized sample-label pair.

The image is converted to a PIL object to ensure compatibility with Torchvision V2 transforms before being returned as a PyTorch Tensor.

Parameters:

  • idx (int) — Sample index. Required.

Returns:

  • tuple[Tensor, Tensor] — A pair of (image, label) where image is a (C, H, W) float tensor and label is a scalar long tensor.

Source code in orchard/data_handler/dataset.py
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Retrieves a standardized sample-label pair.

    The image is converted to a PIL object to ensure compatibility with
    Torchvision V2 transforms before being returned as a PyTorch Tensor.

    Args:
        idx: Sample index.

    Returns:
        A pair of (image, label) where image is a ``(C, H, W)`` float
        tensor and label is a scalar long tensor.
    """
    # Remap index for lazy subsampling (images stay full mmap)
    img_idx = self._indices[idx] if self._indices is not None else idx
    img = self.images[img_idx]

    # Drop ONLY the channel axis for single-channel images. The original
    # ``img.squeeze()`` removed every unit dimension, which corrupted
    # degenerate shapes such as (1, W, 1) or (H, 1, 1) before the PIL
    # conversion.
    pil_img = Image.fromarray(img[..., 0] if img.shape[-1] == 1 else img)

    if self.transform:
        img_t = self.transform(pil_img)
    else:
        img_t = transforms.functional.to_tensor(pil_img)

    return img_t, torch.tensor(int(self.labels[idx]), dtype=torch.long)