Source code for ml.utils.testing

"""Utility function for testing datasets.

All this does is iterates through some samples in a dataset or dataloder. It's
useful when developing a dataset because you can just add a small code snippet
to the bottom of your file like so:

.. code-block:: python

    if __name__ == "__main__":
        from ml.utils.testing import test_dataset, test_task

        test_dataset(MyDataset())
        test_task(MyTask(MyTaskConfig()))
"""

import itertools
import logging
import time
from typing import Any, Callable, TypeVar

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, IterableDataset

from ml.core.state import State
from ml.models.base import BaseModel, BaseModelConfig
from ml.tasks.base import BaseTask
from ml.tasks.rl.base import ReinforcementLearningTask
from ml.tasks.sl.base import SupervisedLearningTask
from ml.utils.checkpoint import is_missing
from ml.utils.logging import configure_logging
from ml.utils.timer import Timer

T = TypeVar("T")

logger: logging.Logger = logging.getLogger(__name__)


def test_dataset(
    ds: Dataset[T] | IterableDataset[T] | DataLoader[T],
    max_samples: int = 3,
    log_interval: int = 10,
    callback: Callable[[T], None] | None = None,
) -> None:
    """Iterates through a dataset.

    Args:
        ds: The dataset to iterate through
        max_samples: Maximum number of samples to loop through
        log_interval: How often to log the time it takes to load a sample
        callback: A callback to run on each sample
    """
    configure_logging()
    start_time = time.time()

    if isinstance(ds, (IterableDataset, DataLoader)):
        logger.info("Iterating samples in %s", "dataloader" if isinstance(ds, DataLoader) else "dataset")
        for i, sample in enumerate(itertools.islice(ds, max_samples)):
            if callback is not None:
                callback(sample)
            if i % log_interval == 0:
                logger.info("Sample %d in %.2g seconds", i, time.time() - start_time)
    else:
        samples = len(ds)  # type: ignore[arg-type]
        logger.info("Dataset has %d items", samples)
        for i in range(min(samples, max_samples)):
            sample = ds[i]
            if callback is not None:
                callback(sample)
            if i % log_interval == 0:
                logger.info("Sample %d in %.2g seconds", i, time.time() - start_time)
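
# A minimal usage sketch with a hypothetical map-style dataset (`RandomDataset`
# is illustrative only, not part of this module):
#
#     class RandomDataset(Dataset[torch.Tensor]):
#         def __len__(self) -> int:
#             return 8
#
#         def __getitem__(self, index: int) -> torch.Tensor:
#             return torch.randn(3, 32, 32)
#
#     test_dataset(RandomDataset(), max_samples=3, callback=lambda t: print(t.shape))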

class DummyModelConfig(BaseModelConfig):
    pass

class DummyModel(BaseModel[BaseModelConfig]):
    def forward(self, *args: Any, **_kwargs: Any) -> Any:  # noqa: ANN401
        return args
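
# A minimal sketch of the dummy model's pass-through behavior (assumes the
# model has been initialized, as `test_task` does below):
#
#     model = DummyModel(DummyModelConfig())
#     out = model(torch.zeros(2, 3))  # returns the args tuple unchanged: (tensor,)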

def test_task(
    task: BaseTask,
    model: BaseModel | None = None,
    max_samples: int = 3,
    log_interval: int = 10,
    default_batch_size: int = 4,
) -> None:
    """Runs some ad-hoc tests on a task.

    This is useful for testing a task while developing it, by running through
    the various parts.

    Args:
        task: The task to test.
        model: The model to use for testing the task. If not provided, a
            dummy model will be created.
        max_samples: Maximum number of samples to loop through, for testing
            the raw dataset
        log_interval: How often to log the time it takes to load a sample,
            for testing the raw dataset
        default_batch_size: The default batch size to use for testing the
            dataloaders
    """
    configure_logging()
    start_time = time.time()

    # Fills in any dataloader settings that the task config leaves unset.
    if is_missing(task.config.train_dl, "batch_size") or not isinstance(task.config.train_dl.batch_size, int):
        task.config.train_dl.batch_size = default_batch_size
    if is_missing(task.config.valid_dl, "batch_size") or not isinstance(task.config.valid_dl.batch_size, int):
        task.config.valid_dl.batch_size = 1
    if is_missing(task.config.test_dl, "batch_size") or not isinstance(task.config.test_dl.batch_size, int):
        task.config.test_dl.batch_size = 1
    if is_missing(task.config.train_dl, "num_workers") or not isinstance(task.config.train_dl.num_workers, int):
        task.config.train_dl.num_workers = 2

    with Timer("initializing model"):
        if model is None:
            model = DummyModel(DummyModelConfig())
            logger.warning("Testing task using dummy Identity model")
        else:
            num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            logger.info("Model has %s trainable parameters", f"{num_trainable_params:,}")
        device, dtype = task._device._get_device(), torch.float32
        model.init(device, dtype)
        task.to(device, dtype)

    if isinstance(task, SupervisedLearningTask):
        with Timer("building dataset", min_seconds_to_print=0.0):
            ds = task.get_dataset("train")
        with Timer("testing dataset", min_seconds_to_print=0.0):
            test_dataset(ds, max_samples, log_interval)
        with Timer("building dataloader", min_seconds_to_print=0.0):
            dl = task.get_dataloader(ds, "train")
        with Timer("building prefetcher", min_seconds_to_print=0.0):
            pf = task._device.get_prefetcher(dl)
        with Timer("getting sample", min_seconds_to_print=0.0):
            sample = next(iter(pf))
        state = State.init_state()
        output = task.run_model(model, sample, state)
        loss = task.compute_loss(model, sample, state, output)
        logger.info("Computed loss (%s) in %.2g seconds", loss, time.time() - start_time)

    elif isinstance(task, ReinforcementLearningTask):
        with Timer("getting worker pool", min_seconds_to_print=0.0):
            worker_pool = task.get_worker_pool()
        with Timer("getting samples", min_seconds_to_print=0.0):
            samples = task.collect_samples(model, worker_pool, 10)
        with Timer("building dataset", min_seconds_to_print=0.0):
            ds = task.build_rl_dataset(samples)
        with Timer("building dataloader", min_seconds_to_print=0.0):
            dl = task.get_dataloader(ds, "train")
        with Timer("building prefetcher", min_seconds_to_print=0.0):
            pf = task._device.get_prefetcher(dl)
        with Timer("getting sample", min_seconds_to_print=0.0):
            sample = next(iter(pf))
        state = State.init_state()
        output = task.run_model(model, sample, state)
        loss = task.compute_loss(model, sample, state, output)
        logger.info("Computed loss (%s) in %.2g seconds", loss, time.time() - start_time)

    else:
        raise NotImplementedError(f"Testing not implemented for task type {task}")
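
# A usage sketch mirroring the module docstring; `MyTask`, `MyTaskConfig`,
# `MyModel`, and `MyModelConfig` are hypothetical placeholders for your own
# classes:
#
#     if __name__ == "__main__":
#         test_task(MyTask(MyTaskConfig()))  # uses the dummy identity model
#         test_task(MyTask(MyTaskConfig()), model=MyModel(MyModelConfig()))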