"""GPU device type.
The default floating point type can be configured with the environment
variables:
- ``USE_FP64``: Use FP64
- ``USE_FP32``: Use FP32
- ``USE_BF16``: Use BF16
"""
import functools
import logging
import os
from typing import Callable
import torch
from ml.core.env import is_gpu_disabled
from ml.utils.device.base import base_device
logger: logging.Logger = logging.getLogger(__name__)
def get_env_bool(key: str) -> bool:
    """Reads a boolean flag from an environment variable.

    The variable is expected to be unset, ``0``, or ``1`` (anything that
    parses to those integers, e.g. ``"01"``, is also accepted, matching the
    original ``int(...)`` semantics).

    Args:
        key: The environment variable name to read.

    Returns:
        True if the variable is set to 1, False if unset or set to 0.

    Raises:
        ValueError: If the variable is set to anything other than 0 or 1.
            (Previously this was an ``assert``, which is stripped under
            ``python -O`` and would silently accept invalid values.)
    """
    raw = os.environ.get(key, "0")
    try:
        val = int(raw)
    except ValueError:
        # Non-numeric values (e.g. "true") get the same clear error message
        # instead of a bare int() parse failure.
        raise ValueError(f"Invalid value for {key}: {raw}") from None
    if val not in (0, 1):
        raise ValueError(f"Invalid value for {key}: {val}")
    return val == 1
class gpu_device(base_device):  # noqa: N801
    """Mixin to support single-GPU training.

    Selects the CUDA device, resolves the default floating point type
    (overridable via the ``USE_FP64`` / ``USE_FP32`` / ``USE_BF16`` /
    ``USE_FP16`` environment variables), and picks a ``torch.compile``
    backend appropriate for the GPU's compute capability.
    """

    @classmethod
    def has_device(cls) -> bool:
        """Returns True if CUDA is usable and GPU use is not explicitly disabled."""
        return torch.cuda.is_available() and torch.cuda.device_count() > 0 and not is_gpu_disabled()

    def _get_device(self) -> torch.device:
        # Constructing a torch.device is trivially cheap, so no caching is
        # needed. The previous ``lru_cache`` on this instance method also
        # kept ``self`` alive for the cache's lifetime (ruff B019).
        return torch.device("cuda")

    def _get_floating_point_type(self) -> torch.dtype:
        # Delegates to a class-level cached resolver; same signature as
        # before, but the cache no longer pins instances (ruff B019).
        return self._detect_floating_point_type()

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _detect_floating_point_type() -> torch.dtype:
        """Resolves the default dtype once per process.

        Environment overrides are checked in priority order (FP64, FP32,
        BF16, FP16); otherwise BF16 is chosen when the GPU's major compute
        capability is >= 8 (Ampere or newer), falling back to FP16.
        """
        if get_env_bool("USE_FP64"):
            return torch.float64
        if get_env_bool("USE_FP32"):
            return torch.float32
        if get_env_bool("USE_BF16"):
            return torch.bfloat16
        if get_env_bool("USE_FP16"):
            return torch.float16
        if torch.cuda.get_device_capability()[0] >= 8:
            return torch.bfloat16
        return torch.float16

    def get_torch_compile_backend(self) -> str | Callable:
        """Chooses a ``torch.compile`` backend for this GPU.

        Inductor requires compute capability >= 7.0; older GPUs fall back
        to the nvFuser-based TorchScript backend.
        """
        if torch.cuda.get_device_capability() >= (7, 0):
            return "inductor"
        return "aot_ts_nvfuser"

    def supports_grad_scaler(self) -> bool:
        """Gradient scaling only applies to reduced-precision (16-bit) training."""
        return self._get_floating_point_type() not in (torch.float32, torch.float64)