PyTorch DataLoader

wsistream provides WsiStreamDataset, an IterableDataset that wraps PatchPipeline and handles multi-worker slide partitioning automatically.

Install the optional torch extra to enable it:

pip install -e ".[torch]"

Basic usage

from torch.utils.data import DataLoader

from wsistream.backends import OpenSlideBackend
from wsistream.sampling import RandomSampler
from wsistream.tissue import OtsuTissueDetector
from wsistream.torch import WsiStreamDataset

dataset = WsiStreamDataset(
    slide_paths=slide_paths,
    backend=OpenSlideBackend(),
    tissue_detector=OtsuTissueDetector(),
    sampler=RandomSampler(patch_size=256, target_mpp=0.5),
    pool_size=8,
    patches_per_slide=100,
    # replacement="without_replacement",  # optional: no repeated coords per slide
)

loader = DataLoader(dataset, batch_size=64, num_workers=4, pin_memory=True)

for batch in loader:
    images = batch["image"]            # (B, 3, H, W) float32
    x = batch["x"]                     # (B,) int — level-0 x coordinates
    y = batch["y"]                     # (B,) int — level-0 y coordinates
    mpp = batch["mpp"]                 # (B,) float — microns/px, -1.0 if unavailable
    tf = batch["tissue_fraction"]      # (B,) float
    paths = batch["slide_path"]        # list[str], length B
    patient = batch["patient_id"]      # list[str], length B (empty if no adapter)

Each batch is a dict of primitives and tensors. Image conversion (HWC uint8 → CHW float32, divided by 255) is handled internally. If a NormalizeTransform is included in the transforms chain, the image is already float32 and is passed through without re-scaling — values will reflect the normalization (e.g., roughly [-2, 3] for ImageNet stats), not [0, 1].
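For reference, the internal uint8 → float32 conversion is equivalent to the following numpy sketch (standalone, no wsistream required; the function name is illustrative, not part of the API):

```python
import numpy as np

def to_chw_float(patch: np.ndarray) -> np.ndarray:
    """Convert an HWC uint8 patch to CHW float32 in [0, 1]."""
    chw = patch.transpose(2, 0, 1)  # HWC -> CHW (a strided view)
    return np.ascontiguousarray(chw, dtype=np.float32) / 255.0

patch = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
out = to_chw_float(patch)
print(out.shape, out.dtype)  # (3, 256, 256) float32
```

As noted above, this scaling step is skipped when a NormalizeTransform has already produced a float32 image.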

With multi-view datasets, each view is collated under its configured name:

batch["global_0"]  # (B, 3, 224, 224)
batch["global_1"]  # (B, 3, 224, 224)
batch["local_0"]   # (B, 3, 96, 96)

Coordinate and metadata fields follow the same schema. See Views for multi-view configuration examples.

Default slide ordering

WsiStreamDataset defaults to slide_sampling="random" (better for training diversity), while PatchPipeline defaults to "sequential". If you need deterministic slide order through the dataset wrapper, pass slide_sampling="sequential" explicitly.

Deterministic validation

If you want the same validation patches every time, use a fixed seed, slide_sampling="sequential", cycle=False, and num_workers=0 on the validation DataLoader. With num_workers>0, PatchPipeline mixes the worker PID into RNG seeds so repeated validation runs are not bit-exact across calls.
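Concretely, a deterministic validation setup looks like the sketch below. It combines the kwargs shown elsewhere on this page; the seed keyword name is an assumption — check your version's WsiStreamDataset signature:

```python
# Sketch of a deterministic validation loader.
# `seed` is an assumed parameter name; all other kwargs appear in these docs.
val_dataset = WsiStreamDataset(
    slide_paths=val_slide_paths,
    backend=OpenSlideBackend(),
    tissue_detector=OtsuTissueDetector(),
    sampler=RandomSampler(patch_size=256, target_mpp=0.5),
    slide_sampling="sequential",  # deterministic slide order
    cycle=False,                  # finite stream: stop after one pass
    seed=1234,                    # fixed seed (parameter name assumed)
)

# num_workers=0 keeps worker PIDs out of the RNG seeding,
# so repeated runs yield the same patches.
val_loader = DataLoader(val_dataset, batch_size=64, num_workers=0)
```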

Why IterableDataset, not Dataset?

A map-style Dataset requires __len__ and __getitem__. Online patching is inherently stochastic: there is no fixed set of patches to index, so neither method has a meaningful answer. IterableDataset streams patches lazily, which matches how online patching produces them. See the PyTorch data loading docs for background on the two dataset styles.
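To illustrate the multi-worker slide partitioning the dataset performs, here is a toy round-robin sketch in plain Python. This is not wsistream's actual code — the real dataset learns its worker id from torch.utils.data.get_worker_info() — but the sharding idea is the same:

```python
def partition_slides(slide_paths, worker_id, num_workers):
    """Round-robin shard: worker i gets slides i, i+num_workers, ..."""
    return slide_paths[worker_id::num_workers]

slides = [f"slide_{i}.svs" for i in range(10)]

# Each worker streams patches only from its own shard, so no slide
# is opened by two workers and no patch source is duplicated.
shards = [partition_slides(slides, w, 4) for w in range(4)]
print(shards[0])  # ['slide_0.svs', 'slide_4.svs', 'slide_8.svs']
```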

Step-based training

With cycle=True (the default in WsiStreamDataset), the pipeline produces an infinite stream. Since patches are randomly sampled from tissue regions, there is no guarantee of seeing the same patches twice — a traditional "epoch" is not meaningful. Training is defined by a number of steps:

loader_iter = iter(loader)

for step in range(total_steps):
    batch = next(loader_iter)
    images = batch["image"].to(device, non_blocking=True)

    loss = model(images)  # assumes the model returns a scalar loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

See Online Patching for why there are no epochs.

Logging and throughput monitoring

Wrap the DataLoader with MonitoredLoader to automatically track data wait time, compute time, and throughput. It also merges dataset.stats_dict() into each payload:

from wsistream.torch import MonitoredLoader

mon = MonitoredLoader(loader, dataset=dataset, device=device, log_every=100)

for step, batch in enumerate(mon):
    images = batch["image"].to(device, non_blocking=True)
    loss = model(images).mean()  # placeholder — replace with your actual loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    payload = mon.mark_step(extra={"train/loss": float(loss.detach())})
    if payload is not None:
        wandb.log(payload, step=step)

See Weights & Biases for details on what metrics are included.

Contiguous arrays

Numpy arrays returned by np.flip or np.rot90 (used by RandomFlipRotate) are strided views — np.flip has negative strides, which torch.from_numpy cannot convert. WsiStreamDataset handles this internally by calling np.ascontiguousarray() before conversion.
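The issue is easy to reproduce with numpy alone: np.rot90 returns a non-contiguous view, and np.ascontiguousarray copies it into C-ordered memory (a standalone illustration, no wsistream needed):

```python
import numpy as np

patch = np.arange(12, dtype=np.uint8).reshape(3, 4)

rotated = np.rot90(patch)              # a strided view, not a copy
print(rotated.flags["C_CONTIGUOUS"])   # False

fixed = np.ascontiguousarray(rotated)  # copy into C-contiguous memory
print(fixed.flags["C_CONTIGUOUS"])     # True
```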

Full example

See examples/train_single_gpu.py and examples/train_ddp.py in the repository for complete working examples.