Weights & Biases
Both PatchPipeline and WsiStreamDataset provide a stats_dict() method that returns a flat dictionary ready for wandb.log(). For PyTorch training loops, MonitoredLoader adds automatic throughput and timing metrics.
With PatchPipeline
For non-PyTorch workflows or custom training loops:
import wandb
from wsistream.datasets import TCGAAdapter
from wsistream.pipeline import PatchPipeline
wandb.init(project="pathology-fm")
pipeline = PatchPipeline(
    slide_paths="/data/tcga",
    dataset_adapter=TCGAAdapter(),
    ...
)
for step, result in enumerate(pipeline):
    # ... training step ...
    if step % 500 == 0:
        wandb.log(pipeline.stats_dict(), step=step)
pipeline.stats_dict() returns aggregate slide, patch, and tissue statistics:
pipeline.stats_dict()
# {
#     "pipeline/slides_processed": 42,
#     "pipeline/slides_failed": 1,
#     "pipeline/slides_unique": 40,
#     "pipeline/patches_extracted": 4200,
#     "pipeline/patches_filtered": 380,
#     "pipeline/mean_tissue_fraction": 0.63,
#     "pipeline/min_tissue_fraction": 0.21,
#     "pipeline/max_tissue_fraction": 0.91,
#     "pipeline/mpp_0.25": 1050,  # ~40x
#     "pipeline/mpp_0.50": 1050,  # ~20x
#     "pipeline/mpp_1.00": 1050,  # ~10x
#     "pipeline/mpp_2.00": 1050,  # ~5x
#     "pipeline/cancer_type/TCGA-BRCA": 15,
#     "pipeline/cancer_type/TCGA-LUAD": 12,
#     "pipeline/sample_type/primary_solid_tumor": 38,
#     "pipeline/sample_type/solid_tissue_normal": 4,
#     "pipeline/error_count": 1,
# }
The magnification keys (mpp_*) are only present when patches are extracted at known MPP values. The cancer/sample type keys are only present when a DatasetAdapter is configured. error_count is only present when errors have occurred.
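Because several keys are conditional, code that reads individual entries is safer with dict.get or a prefix filter than with direct indexing. A minimal sketch follows. It uses a hand-written sample dictionary in place of a real pipeline.stats_dict() result, and the helper name summarize_mpp is hypothetical, not part of wsistream:

```python
def summarize_mpp(stats: dict) -> dict:
    """Collect the optional per-MPP patch counts from a stats_dict()-style dict."""
    prefix = "pipeline/mpp_"
    # Keys such as "pipeline/mpp_0.25" may be absent, so filter instead of indexing.
    return {
        float(key[len(prefix):]): count
        for key, count in stats.items()
        if key.startswith(prefix)
    }


# Hand-written sample standing in for pipeline.stats_dict(); values are illustrative.
sample = {
    "pipeline/slides_processed": 42,
    "pipeline/patches_extracted": 4200,
    "pipeline/mpp_0.25": 1050,
    "pipeline/mpp_0.50": 1050,
}

mpp_counts = summarize_mpp(sample)
# "pipeline/error_count" is missing from the sample, so fall back to 0.
errors = sample.get("pipeline/error_count", 0)
```

The same pattern applies to the cancer_type and sample_type keys, which only appear when a DatasetAdapter is configured.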
For programmatic access to raw data, use the stats property:
stats = pipeline.stats
stats.slides_processed # int — total opens (includes revisits)
stats.slides_failed # int
stats.slides_seen # set[str] — unique slide paths seen
stats.patches_extracted # int
stats.patches_filtered # int (rejected by PatchFilter)
stats.tissue_fractions.count # int — number of slides
stats.tissue_fractions.mean # float | None
stats.tissue_fractions.min_val # float
stats.tissue_fractions.max_val # float
stats.error_count # int — total errors recorded
stats.recent_errors # deque — most recent 100 (slide_path, message) pairs
Use pipeline.reset_stats() to clear accumulated counters at any point.
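Of the fields listed above, recent_errors is bounded while error_count is not, so the two can diverge on long runs. The sketch below groups the retained errors by slide. It assumes recent_errors behaves like a collections.deque(maxlen=100) of (slide_path, message) tuples; the deque here is a hand-built stand-in, and the paths and messages are made up:

```python
from collections import Counter, deque

# Stand-in for stats.recent_errors (assumed: bounded deque of (slide_path, message)).
recent_errors = deque(maxlen=100)
error_count = 0

# Simulate 150 failures spread over three hypothetical slides.
for i in range(150):
    recent_errors.append((f"/data/slide_{i % 3}.svs", "read failed"))
    error_count += 1

# Count how many of the retained errors belong to each slide.
errors_per_slide = Counter(path for path, _ in recent_errors)

print(error_count)         # every error is counted
print(len(recent_errors))  # only the newest entries are kept
print(errors_per_slide.most_common(1))
```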
With WsiStreamDataset + MonitoredLoader (PyTorch)
MonitoredLoader wraps the DataLoader and automatically tracks data wait time, compute time, and throughput. It also includes dataset.stats_dict() in every payload:
import wandb
from torch.utils.data import DataLoader
from wsistream.torch import MonitoredLoader, WsiStreamDataset
wandb.init(project="pathology-fm")
dataset = WsiStreamDataset(
    slide_paths="/data/tcga",
    ...
)
loader = DataLoader(dataset, batch_size=64, num_workers=4, pin_memory=True)
# model, optimizer, and device are assumed to be defined earlier in the training script
mon = MonitoredLoader(loader, dataset=dataset, device=device, log_every=100)
for step, batch in enumerate(mon):
    images = batch["image"].to(device, non_blocking=True)
    loss = model(images).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    payload = mon.mark_step(extra={"train/loss": float(loss)})
    if payload is not None:
        wandb.log(payload, step=step)
Each payload includes:
| Key | Description |
|---|---|
| loader/data_wait_ms | Average time waiting for the next batch (per step) |
| loader/compute_ms | Average time in the training step (per step) |
| loader/step_ms | Total step time (data + compute) |
| loader/data_fraction | Fraction of step time spent waiting for data |
| loader/batches_per_sec | Training step throughput |
| loader/patches_per_sec | Patch throughput |
| pipeline/* | All keys from dataset.stats_dict() |
If loader/data_fraction is consistently above 0.5, the data pipeline is the bottleneck: increase num_workers or check your storage I/O. See Benchmarking for how to find optimal settings before training.
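The 0.5 rule of thumb can be checked over a window of payloads instead of a single one. The sketch below assumes loader/data_fraction equals data_wait_ms divided by step_ms, which matches the table but is not confirmed against the implementation. The helpers data_fraction and is_data_bound and the sample numbers are hypothetical:

```python
def data_fraction(data_wait_ms: float, compute_ms: float) -> float:
    """Fraction of step time spent waiting for data (assumed definition)."""
    step_ms = data_wait_ms + compute_ms
    return data_wait_ms / step_ms if step_ms > 0 else 0.0


def is_data_bound(payloads: list, threshold: float = 0.5) -> bool:
    """True when every payload in the window exceeds the threshold."""
    fractions = [p["loader/data_fraction"] for p in payloads]
    return bool(fractions) and all(f > threshold for f in fractions)


# Made-up payloads: 120 ms waiting against 80 ms compute gives a fraction of 0.6.
window = [
    {"loader/data_fraction": data_fraction(120.0, 80.0)},
    {"loader/data_fraction": data_fraction(150.0, 90.0)},
    {"loader/data_fraction": data_fraction(110.0, 100.0)},
]

data_bound = is_data_bound(window)
```

Requiring every payload in the window to exceed the threshold is one reading of "consistently"; a mean or a percentile over the window would be an equally reasonable choice.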
Live stats accuracy
With num_workers=0, pipeline stats are exact on every step. With num_workers>0, stats are aggregated from worker processes and may lag by up to 16 patches per worker between flushes. Final totals at the end of iteration are always correct.
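As a rough worked example of the lag bound, the worst case scales with the number of workers. The helper max_stats_lag below is hypothetical; the constant 16 is taken from the note above, not read from the library:

```python
FLUSH_INTERVAL_PATCHES = 16  # per-worker bound quoted in the note above


def max_stats_lag(num_workers: int) -> int:
    """Upper bound on patches not yet reflected in live stats."""
    # num_workers=0 runs in the main process, where stats are exact.
    return FLUSH_INTERVAL_PATCHES * max(num_workers, 0)


lag = max_stats_lag(4)  # 4 workers * 16 patches = 64 patches
```

With num_workers=4 and batch_size=64, the live counters may therefore trail by about one batch.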
DDP
In DDP, each rank has its own MonitoredLoader. All ranks call mark_step() every step, but only rank 0 logs:
mon = MonitoredLoader(loader, dataset=dataset, device=device, log_every=100)
for step, batch in enumerate(mon):
    # ... training step ...
    payload = mon.mark_step(extra={"train/loss": float(loss)})
    if payload is not None and rank == 0:
        wandb.log(payload, step=step)
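The snippet above uses rank without defining it. One common way to obtain it, assuming the job is launched with torchrun (which sets the RANK environment variable for each process), is sketched below. The helper get_rank is illustrative and not part of wsistream:

```python
import os


def get_rank() -> int:
    """Return the global rank, defaulting to 0 for single-process runs."""
    # torchrun exports RANK for every process it launches.
    return int(os.environ.get("RANK", "0"))


rank = get_rank()
```

Once the process group has been initialized, torch.distributed.get_rank() returns the same value.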
For per-rank straggler detection, log from every rank with a prefix: