Source code for ml.trainers.sl

"""Defines a trainer to use for supervised learning.

This trainer is akin to PyTorch Lightning or Keras, in that it handles
the training loop, logging, and checkpointing. We get a dataset and dataloader
from the task, and then train the model on the dataset.
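
A rough usage sketch (hypothetical; in practice the model, task, optimizer, and
learning rate scheduler are usually built from a config through the registry)::

    trainer = SupervisedLearningTrainer(config)
    trainer.train(model, task, optimizer, lr_scheduler)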
"""

import bisect
import contextlib
import functools
import itertools
import logging
import signal
from dataclasses import dataclass
from types import FrameType
from typing import Generic, Iterator, TypeVar

from omegaconf import MISSING

from ml.core.common_types import Batch
from ml.core.config import conf_field
from ml.core.registry import register_trainer
from ml.core.state import State
from ml.lr_schedulers.base import BaseLRScheduler
from ml.models.gan import GenerativeAdversarialNetworkModel
from ml.optimizers.base import BaseOptimizer
from ml.tasks.gan.base import GenerativeAdversarialNetworkTask
from ml.tasks.gan.round_robin import GenerativeAdversarialNetworkRoundRobinTask
from ml.tasks.sl.base import SupervisedLearningTask
from ml.trainers.base import ModelT
from ml.trainers.learning import BaseLearningTrainer, BaseLearningTrainerConfig
from ml.utils.containers import recursive_chunk
from ml.utils.device.base import InfinitePrefetcher
from ml.utils.exceptions import EpochDoneError, TrainingFinishedError
from ml.utils.timer import Timer

logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class ValidationConfig:
    valid_every_n_steps: int | None = conf_field(None, help="Number of training steps to run per validation step")
    valid_every_n_seconds: float | None = conf_field(60.0 * 10.0, help="Run validation every N seconds")
    valid_first_n_seconds: float | None = conf_field(60.0, help="Run first validation after N seconds")

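# Note: inferred from `should_validate` below. Validation runs whenever
# `valid_every_n_steps` training steps or `valid_every_n_seconds` seconds have
# elapsed since the previous validation pass, and `valid_first_n_seconds`
# schedules one extra early validation pass shortly after training starts.
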
@dataclass
class BatchScheduleConfig:
    num_steps: int = conf_field(MISSING, help="Number of steps to run for")
    num_batches: int = conf_field(MISSING, help="Number of minibatches for a given step")

@dataclass
class SupervisedLearningTrainerConfig(BaseLearningTrainerConfig):
    validation: ValidationConfig = conf_field(ValidationConfig())
    batches_per_step: int = conf_field(1, help="Batches per training step, to simulate larger effective batch sizes")
    batches_per_step_schedule: list[BatchScheduleConfig] | None = conf_field(
        None,
        help="A schedule for the number of minibatches per step, as a list of (step_count, num_batches) tuples.",
    )
    batch_chunks_per_step_schedule: list[BatchScheduleConfig] | None = conf_field(
        None,
        help="A schedule for chunking the batches, as a list of (step_count, num_chunks) tuples.",
    )
    batch_dim: int = conf_field(0, help="The batch dimension, for splitting batches into chunks")

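# Note: a sketch of the schedule semantics, inferred from `batches_per_step_schedule`
# and `get_batches_per_step` below. Each BatchScheduleConfig entry covers `num_steps`
# consecutive training steps, so a schedule such as
#
#     batches_per_step_schedule = [
#         BatchScheduleConfig(num_steps=1000, num_batches=4),
#         BatchScheduleConfig(num_steps=1000, num_batches=2),
#     ]
#
# accumulates 4 minibatches per optimizer step for the first 1000 steps, then 2
# per step for the next 1000 steps, and keeps the last value (2) afterwards.
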
SupervisedLearningTrainerConfigT = TypeVar("SupervisedLearningTrainerConfigT", bound=SupervisedLearningTrainerConfig)
SupervisedLearningTaskT = TypeVar("SupervisedLearningTaskT", bound=SupervisedLearningTask)

@register_trainer("sl", SupervisedLearningTrainerConfig)
class SupervisedLearningTrainer(
    BaseLearningTrainer[SupervisedLearningTrainerConfigT, ModelT, SupervisedLearningTaskT],
    Generic[SupervisedLearningTrainerConfigT, ModelT, SupervisedLearningTaskT],
):
    def __init__(self, config: SupervisedLearningTrainerConfigT) -> None:
        super().__init__(config)

        self.last_valid_time: float | None = None
        self.last_valid_step: int | None = None
        self.first_valid_step_flag = True

    @functools.lru_cache()
    def batches_per_step_schedule(self) -> list[tuple[int, int]] | None:
        schedule = self.config.batches_per_step_schedule
        if schedule is None:
            return None
        if any(s.num_steps <= 0 or s.num_batches <= 0 for s in schedule):
            raise ValueError("num_steps and num_batches must be positive")
        schedule_list = [(s.num_steps, s.num_batches) for s in schedule]
        schedule_cumsum = list(itertools.accumulate([0] + [s[0] for s in schedule_list]))
        return list(zip(schedule_cumsum[1:], [s[1] for s in schedule_list]))

    def get_batches_per_step(self, state: State) -> int:
        if (schedule := self.batches_per_step_schedule()) is None:
            return self.config.batches_per_step
        i = bisect.bisect_left(schedule, (state.num_steps, 0))
        return schedule[-1][1] if i == len(schedule) else schedule[i][1]

    @functools.lru_cache()
    def batch_chunks_schedule(self) -> list[tuple[int, int]] | None:
        schedule = self.config.batch_chunks_per_step_schedule
        if schedule is None:
            return None
        if any(s.num_steps <= 0 or s.num_batches <= 0 for s in schedule):
            raise ValueError("num_steps and num_batches must be positive")
        schedule_list = [(s.num_steps, s.num_batches) for s in schedule]
        schedule_cumsum = list(itertools.accumulate([0] + [s[0] for s in schedule_list]))
        return list(zip(schedule_cumsum[1:], [s[1] for s in schedule_list]))

    def get_batch_chunks(self, state: State) -> int:
        if (schedule := self.batch_chunks_schedule()) is None:
            return 1
        i = bisect.bisect_left(schedule, (state.num_steps, 0))
        return 1 if i == len(schedule) else schedule[i][1]

    def should_validate(self, state: State) -> bool:
        if self.last_valid_time is None or self.last_valid_step is None:
            self.last_valid_time = state.elapsed_time_s
            self.last_valid_step = state.num_steps
            return True

        # Step-based validation.
        valid_every_n_steps = self.config.validation.valid_every_n_steps
        if valid_every_n_steps is not None and state.num_steps > valid_every_n_steps + self.last_valid_step:
            self.last_valid_step = state.num_steps
            return True

        # Time-based validation.
        valid_every_n_seconds = self.config.validation.valid_every_n_seconds
        if valid_every_n_seconds is not None and state.elapsed_time_s - self.last_valid_time >= valid_every_n_seconds:
            self.last_valid_time = state.elapsed_time_s
            return True

        # Time-based validation for first validation step.
        if self.first_valid_step_flag:
            valid_first_n_seconds = self.config.validation.valid_first_n_seconds
            if valid_first_n_seconds is not None and state.elapsed_time_s >= valid_first_n_seconds:
                self.last_valid_time = state.elapsed_time_s
                self.first_valid_step_flag = False
                return True

        return False

    def train(
        self,
        model: ModelT,
        task: SupervisedLearningTaskT,
        optimizer: BaseOptimizer,
        lr_scheduler: BaseLRScheduler,
    ) -> None:
        """Runs the training loop.

        Args:
            model: The current model
            task: The current task
            optimizer: The current optimizer
            lr_scheduler: The current learning rate scheduler

        Raises:
            ValueError: If the task is not a supervised learning task
        """
        if not isinstance(task, SupervisedLearningTask):
            raise ValueError(f"Expected task to be a SupervisedLearningTask, got {type(task)}")
        if isinstance(model, GenerativeAdversarialNetworkModel):
            raise ValueError("Model is a GenerativeAdversarialNetworkModel, use `gan` trainer instead")
        if isinstance(task, (GenerativeAdversarialNetworkTask, GenerativeAdversarialNetworkRoundRobinTask)):
            raise ValueError("Task is a GenerativeAdversarialNetworkTask, use `gan` trainer instead")

        self._init_environment()

        with Timer("compiling model"):
            model = self._compile_model(model)

        with Timer("compiling training step"):
            train_step = self._compile_func(self.train_step)

        with Timer("compiling validation step"):
            val_step = self._compile_func(self.val_step)

        with Timer("building task model"):
            task_model = self._get_task_model(task, model)

        optim, lr_sched = self._get_optim_and_lr_sched(task_model, optimizer, lr_scheduler)
        state = self._get_state(task, model, optim, lr_sched)

        def on_exit(signum: int, _: FrameType | None) -> None:
            sig = signal.Signals(signum)
            self.on_exit(sig, state, task, model, optim, lr_sched)

        # Handle user-defined interrupts.
        signal.signal(signal.SIGUSR1, on_exit)

        # Gets the datasets.
        with Timer("getting datasets", 0.1):
            train_ds = task.get_dataset("train")
            valid_ds = task.get_dataset("valid")

        # Gets the dataloaders.
        with Timer("getting dataloaders", 0.1):
            train_dl = task.get_dataloader(train_ds, "train")
            valid_dl = task.get_dataloader(valid_ds, "valid")

        # Gets the prefetchers.
with Timer("getting prefetchers", 0.1): train_pf = self._device.get_prefetcher(train_dl) valid_pf = self._device.get_prefetcher(valid_dl) valid_pf_iter = iter(InfinitePrefetcher(valid_pf)) self.on_training_start(state, task, model, optim, lr_sched) try: with contextlib.ExitStack() as ctx: profile = self.get_profile() if profile is not None: ctx.enter_context(profile) while True: with self.step_context("on_epoch_start"): self.on_epoch_start(state, task, model, optim, lr_sched) def batch_splitter() -> Iterator[Batch]: num_chunks = self.get_batch_chunks(state) for batch in train_pf: yield from recursive_chunk(batch, num_chunks, dim=self.config.batch_dim) train_pf_iter: Iterator = batch_splitter() def batch_iterator() -> Iterator[Batch]: try: yield next(train_pf_iter) except StopIteration: raise EpochDoneError for _ in range(self.get_batches_per_step(state) - 1): try: yield next(train_pf_iter) except StopIteration: pass while True: if self.should_validate(state): self._log_prefetcher_stats(valid_pf) val_step( task_model=task_model, batch=next(valid_pf_iter), state=state, task=task, model=model, ) self._log_prefetcher_stats(train_pf) if task.is_training_over(state): raise TrainingFinishedError with self.step_context("on_step_start"): self.on_step_start(state, task, model, optim, lr_sched) try: loss_dict = train_step( task_model=task_model, batches=batch_iterator(), state=state, task=task, model=model, optim=optim, lr_sched=lr_sched, ) except EpochDoneError: break if self.should_checkpoint(state): self.save_checkpoint(state, task, model, optim, lr_sched) if profile is not None: profile.step() with self.step_context("on_step_end"): self.on_step_end(state, loss_dict, task, model, optim, lr_sched) with self.step_context("on_epoch_end"): self.on_epoch_end(state, task, model, optim, lr_sched) except TrainingFinishedError: self.save_checkpoint(state, task, model, optim, lr_sched) logger.info( "Finished training after %d epochs, %d steps, %d samples", state.num_epochs, state.num_steps, state.num_samples, ) except Exception: logger.exception("Caught exception during training loop for %s", self.config_path) finally: self.on_training_end(state, task, model, optim, lr_sched)