# Source code for ml.utils.device.cpu
"""CPU device type."""
from typing import Callable
import torch
from ml.utils.device.base import base_device
class cpu_device(base_device):  # noqa: N801
    """Mixin to support CPU training.

    Implements the private device hooks ``_get_device`` and
    ``_get_floating_point_type`` for the CPU. These presumably override hooks
    declared by ``base_device`` (not visible here) -- confirm against the base
    class.
    """

    def _get_device(self) -> torch.device:
        """Return the CPU torch device.

        Returns:
            A new ``torch.device("cpu")`` on every call.
        """
        return torch.device("cpu")

    def _get_floating_point_type(self) -> torch.dtype:
        """Return the floating point dtype used for the CPU.

        Returns:
            ``torch.float32`` always; this mixin never selects a
            reduced-precision type.
        """
        return torch.float32