Source code for ml.models.base

"""The base object and config for all models.

This is essentially just a small wrapper around a vanilla PyTorch module.
"""

import logging
from dataclasses import dataclass
from typing import Generic, TypeVar

import torch
from torch import Tensor, nn

from ml.core.config import BaseConfig, BaseObject
from ml.loggers.multi import MultiLogger
from ml.utils.colors import colorize
from ml.utils.device.base import allow_nonblocking

# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class BaseModelConfig(BaseConfig):
    """Base configuration class that every model config derives from.

    It declares no fields of its own on top of ``BaseConfig``; it exists so
    that model configs share a common parent type.
    """
# Type variable for model configs; the bound restricts it to
# ``BaseModelConfig`` and its subclasses.
ModelConfigT = TypeVar("ModelConfigT", bound=BaseModelConfig)
def summarize(names: list[tuple[str, torch.device]]) -> str:
    """Formats ``(name, device)`` pairs as one entry per line.

    Args:
        names: Pairs of a tensor name and the device it lives on.

    Returns:
        The concatenation of one ``"\\n<name> - <device>"`` entry per pair,
        with the name passed through ``colorize(..., 'red')``. Each entry
        starts with its own newline, so an empty input gives an empty string.
    """
    entries: list[str] = []
    for tensor_name, tensor_device in names:
        entries.append(f"\n{colorize(tensor_name, 'red')} - {tensor_device}")
    return "".join(entries)
class BaseModel(BaseObject[ModelConfigT], Generic[ModelConfigT], nn.Module):
    """Defines the base module type.

    Combines ``BaseObject`` (which receives the config) with ``nn.Module``,
    and adds helpers for placing the module's tensors on a device and for
    moving external tensors to the module's device and dtype.
    """

    __constants__ = ["config"]

    def __init__(self, config: ModelConfigT) -> None:
        """Initializes both base classes and the model logger.

        Args:
            config: The model configuration, forwarded to ``BaseObject``.
        """
        # Both bases are initialized explicitly because they take different
        # constructor arguments.
        nn.Module.__init__(self)
        BaseObject.__init__(self, config)

        # Used to log values to the trainer.
        self.logger = MultiLogger(default_namespace="model")

    def init(self, device: torch.device, dtype: torch.dtype | None = None) -> None:
        """Moves the module's tensors to ``device`` and reports stragglers.

        Meta tensors are left untouched. Floating-point tensors are also cast
        to ``dtype`` (when given); other tensors only change device. After
        the move, a warning is logged listing up to five parameters and up
        to five buffers whose device type still differs from ``device``.

        Args:
            device: The target device.
            dtype: Optional dtype applied to floating-point tensors only.
        """

        # Moves all non-meta tensors to the first device.
        def move_to_device(t: Tensor) -> Tensor:
            if t.is_meta:
                return t
            non_blocking = allow_nonblocking(device, t.device)
            if t.is_floating_point():
                return t.to(device=device, dtype=dtype, non_blocking=non_blocking)
            return t.to(device=device, non_blocking=non_blocking)

        self._apply(move_to_device)

        # Only the device *type* is compared, so a differing device index
        # alone does not trigger a warning.
        checks = (
            ("params", self.named_parameters()),
            ("buffers", self.named_buffers()),
        )
        for kind, named_tensors in checks:
            misplaced = {(name, t.device) for name, t in named_tensors if t.device.type != device.type}
            if not misplaced:
                continue
            shown = sorted(misplaced)[:5]
            # ``summarize`` starts every entry with a newline, so the format
            # string must not add another one.
            logger.warning(
                "Got %d %s which are on a different device from %s. First %d:%s",
                len(misplaced),
                kind,
                device,
                len(shown),
                summarize(shown),
            )

    @torch.jit.ignore
    def get_device(self) -> torch.device:
        """Returns the device of the module's first parameter.

        Raises:
            StopIteration: If the module has no parameters.
        """
        return next(self.parameters()).device

    @torch.jit.ignore
    def get_dtype(self) -> torch.dtype:
        """Returns the dtype of the module's first floating-point parameter.

        Raises:
            StopIteration: If the module has no floating-point parameters.
        """
        return next(p for p in self.parameters() if p.is_floating_point()).dtype

    @torch.jit.ignore
    def tensor_to(self, tensor: Tensor, non_blocking: bool = False) -> Tensor:
        """Moves ``tensor`` to the module's device, casting it when applicable.

        Floating-point and complex tensors are cast to the module's dtype;
        all other tensors only change device.

        Args:
            tensor: The tensor to move.
            non_blocking: Requests a non-blocking copy. It is only honored
                when ``allow_nonblocking`` also permits it for the source
                and target devices.

        Returns:
            The result of ``tensor.to(...)`` for the module's device (and
            dtype, where applicable).

        Raises:
            StopIteration: Propagated from ``get_device`` / ``get_dtype`` if
                the module has no suitable parameters.
        """
        device, dtype = self.get_device(), self.get_dtype()
        use_non_blocking = non_blocking and allow_nonblocking(device, tensor.device)
        if tensor.is_floating_point() or tensor.is_complex():
            return tensor.to(device, dtype, non_blocking=use_non_blocking)
        return tensor.to(device, non_blocking=use_non_blocking)