# Source code for ml.utils.device.cpu
"""CPU device type."""
from typing import Callable
import torch
from ml.utils.device.base import base_device
class cpu_device(base_device):  # noqa: N801
    """Mixin to support CPU training.

    Overrides the ``_get_device`` and ``_get_floating_point_type`` hooks
    (presumably declared on ``base_device`` — confirm there) so that the
    host CPU is the target device and tensors use single-precision floats.
    """

    def _get_device(self) -> torch.device:
        """Return the CPU device.

        Returns:
            ``torch.device("cpu")``.
        """
        return torch.device("cpu")

    def _get_floating_point_type(self) -> torch.dtype:
        """Return the floating-point dtype to use on this device.

        Returns:
            ``torch.float32``: full single precision is used on CPU.
        """
        return torch.float32