"""Defines functions used for mixed precision training."""
from typing import Iterable, cast
import torch
from torch import Tensor, inf, nn
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, _has_foreach_support

# Tensors grouped by (device, dtype), as returned by _group_tensors_by_device_and_dtype:
# each value holds the grouped tensor lists together with their original indices.
GradDict = dict[tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]]]
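
# Illustrative sketch only (not executed): the rough shape of a GradDict for a
# model whose gradients g0, g2 are float32 and g1 is float16 on one CUDA device.
# The exact layout is an assumption based on _group_tensors_by_device_and_dtype.
#
#     {
#         (torch.device("cuda:0"), torch.float32): ([[g0, g2]], [0, 2]),
#         (torch.device("cuda:0"), torch.float16): ([[g1]], [1]),
#     }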


@torch.no_grad()
def get_weight_norm(
    parameters: Iterable[nn.Parameter],
    norm_type: float = 2.0,
    foreach: bool | None = None,
) -> Tensor:
    """Computes the norm of an iterable of parameters.

    The norm is computed over all parameters together, as if they were
    concatenated into a single vector.

    Args:
        parameters: An iterable of the model parameters.
        norm_type: The type of the used p-norm.
        foreach: Use the faster foreach-based implementation.

    Returns:
        The total norm of the parameters (viewed as a single vector).
    """
    parameters = list(parameters)
    if len(parameters) == 0:
        return torch.tensor([0.0])
    first_device = parameters[0].device
    grouped_params = cast(GradDict, _group_tensors_by_device_and_dtype([[p.detach() for p in parameters]]))
    if norm_type == inf:
        norms = [p.detach().abs().max().to(first_device) for p in parameters]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        norms = []
        for (device, _), ([device_params], _) in grouped_params.items():
            if (foreach is None or foreach) and _has_foreach_support(device_params, device=device):
                norms.extend(torch._foreach_norm(device_params, norm_type))
            else:
                norms.extend([torch.norm(p, norm_type) for p in device_params])
        total_norm = torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
    return total_norm
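
# A minimal usage sketch for get_weight_norm. The toy nn.Linear model below is
# an assumption for illustration and is not part of this module.
#
#     model = nn.Linear(8, 4)
#     l2_norm = get_weight_norm(model.parameters())                  # global L2 norm
#     max_abs = get_weight_norm(model.parameters(), norm_type=inf)   # largest |w|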


@torch.no_grad()
def get_grad_norm(
    parameters: Iterable[nn.Parameter],
    norm_type: float = 2.0,
    foreach: bool | None = None,
) -> tuple[Tensor, GradDict]:
    """Computes the gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Parameters without gradients are
    ignored.

    Args:
        parameters: An iterable of the model parameters.
        norm_type: The type of the used p-norm.
        foreach: Use the faster foreach-based implementation.

    Returns:
        The total norm of the gradients (viewed as a single vector) and the
        gradients grouped by device and dtype.
    """
    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        return torch.tensor([0.0]), {}
    first_device = grads[0].device
    grouped_grads = cast(GradDict, _group_tensors_by_device_and_dtype([[g.detach() for g in grads]]))
    if norm_type == inf:
        norms = [g.detach().abs().max().to(first_device) for g in grads]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        norms = []
        for (device, _), ([device_grads], _) in grouped_grads.items():
            if (foreach is None or foreach) and _has_foreach_support(device_grads, device=device):
                norms.extend(torch._foreach_norm(device_grads, norm_type))
            else:
                norms.extend([torch.norm(g, norm_type) for g in device_grads])
        total_norm = torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
    return total_norm, grouped_grads
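
# A minimal usage sketch for get_grad_norm. The model, input, and loss below
# are illustrative assumptions, not part of this module.
#
#     model = nn.Linear(8, 4)
#     loss = model(torch.randn(2, 8)).sum()
#     loss.backward()
#     total_norm, grouped = get_grad_norm(model.parameters())
#     # total_norm is a scalar tensor; grouped maps (device, dtype) -> gradient lists.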


@torch.no_grad()
def clip_grad_norm_(
    parameters: Iterable[nn.Parameter],
    max_norm: float,
    norm_type: float = 2.0,
    foreach: bool | None = None,
) -> tuple[Tensor, bool]:
    """Clips the gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.
    If the total norm is non-finite (NaN or inf), the gradients are left
    untouched and ``False`` is returned as the second value.

    Args:
        parameters: An iterable of the model parameters.
        max_norm: The maximum norm of the gradients.
        norm_type: The type of the used p-norm.
        foreach: Use the faster foreach-based implementation. If ``None``, use
            the foreach implementation for CUDA and CPU native tensors and
            silently fall back to the slow implementation for other device
            types. If ``True`` or ``False``, use the foreach or non-foreach
            implementation, respectively, and raise an error if the chosen
            implementation is not available.

    Returns:
        The total norm of the gradients (viewed as a single vector) and
        whether the gradients were successfully clipped.
    """
    total_norm, grouped_grads = get_grad_norm(parameters, norm_type, foreach)
    if not torch.isfinite(total_norm):
        return total_norm, False
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for (device, _), ([device_grads], _) in grouped_grads.items():
        if (foreach is None or foreach) and _has_foreach_support(device_grads, device=device):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.detach().mul_(clip_coef_clamped_device)
    return total_norm, True
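
# A minimal training-step sketch using clip_grad_norm_. The model, optimizer,
# and data below are illustrative assumptions, not part of this module. In a
# mixed-precision setup, gradients are typically unscaled (e.g. via
# GradScaler.unscale_) before clipping so the norm is measured in real units.
#
#     model = nn.Linear(8, 4)
#     optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
#     loss = model(torch.randn(2, 8)).sum()
#     loss.backward()
#     grad_norm, clipped = clip_grad_norm_(model.parameters(), max_norm=1.0)
#     if clipped:                  # skip the step if gradients were non-finite
#         optimizer.step()
#     optimizer.zero_grad()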