"""Common optimizer utilities."""
from typing import Any, Iterable
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
def separate_decayable_params(model: nn.Module, default_decay: bool, weight_decay: float) -> Iterable[dict[str, Any]]:
    """Split a model's parameters into weight-decay and no-weight-decay groups.

    This is mostly taken from nanoGPT.

    Each parameter is classified by the module that directly owns it, using
    the first rule that matches:

    1. Parameters with fewer than two dimensions (biases, norm scales) are
       never decayed.
    2. Parameters owned by normalization or embedding modules are never
       decayed.
    3. Parameters owned by linear, convolution or multi-head attention
       modules are always decayed.
    4. Everything else follows ``default_decay``.

    A parameter shared between several modules (tied weights) is classified
    once, by the first module that owns it in ``model.named_modules()``
    order, so it lands in exactly one group.

    Args:
        model: The model to get the parameters for
        default_decay: Whether to decay by default (for modules which aren't
            explicitly specified)
        weight_decay: The weight decay to use

    Returns:
        A two-element list of parameter groups to pass to the optimizer: the
        decayed parameters (with ``weight_decay``) followed by the
        non-decayed parameters (with a weight decay of ``0.0``). Within each
        group, parameters are ordered by their fully-qualified name.
    """
    wd_params: set[str] = set()
    no_wd_params: set[str] = set()
    seen: set[int] = set()

    always_decay = (
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.MultiheadAttention,
    )

    never_decay = (
        _BatchNorm,
        nn.LocalResponseNorm,
        nn.GroupNorm,
        nn.LayerNorm,
        nn.Embedding,
        nn.EmbeddingBag,
    )

    for mn, m in model.named_modules():
        # ``recurse=False`` is essential: ``named_modules`` yields the root
        # module first, and a recursive walk from the root would claim every
        # parameter with ``m`` being the root, so the ``isinstance`` checks
        # below would never look at the layer that actually owns the tensor.
        for pn, p in m.named_parameters(recurse=False):
            # De-duplicate by identity so tied parameters are grouped once,
            # under the same name that ``model.named_parameters()`` reports.
            if id(p) in seen:
                continue
            seen.add(id(p))

            fpn = f"{mn}.{pn}" if mn else pn
            if p.ndim < 2 or isinstance(m, never_decay):
                no_wd_params.add(fpn)
            elif isinstance(m, always_decay) or default_decay:
                wd_params.add(fpn)
            else:
                no_wd_params.add(fpn)

    param_dict = dict(model.named_parameters())

    # Internal sanity checks: the two sets must partition the parameters.
    inter_params = wd_params & no_wd_params
    union_params = wd_params | no_wd_params
    assert len(inter_params) == 0, "Parameters made it into both decay and no-decay sets!"
    assert len(param_dict.keys() - union_params) == 0, "Parameters were not separated into decay or no-decay set!"

    return [
        {"params": [param_dict[pn] for pn in sorted(wd_params)], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(no_wd_params)], "weight_decay": 0.0},
    ]
def can_use_fused(model: nn.Module) -> bool:
    """Check whether every parameter of the model is a CUDA floating-point tensor.

    Presumably this is the precondition for passing ``fused=True`` to a
    PyTorch optimizer; confirm against the optimizer being constructed.

    Args:
        model: The model whose parameters should be checked

    Returns:
        True if all parameters live on a CUDA device and have a floating-point
        dtype. A model without any parameters also returns True.
    """
    return all(p.is_cuda and p.is_floating_point() for p in model.parameters())
def can_use_foreach(model: nn.Module) -> bool:
    """Check whether every parameter is a floating-point tensor on CPU or CUDA.

    Presumably this is the precondition for passing ``foreach=True`` to a
    PyTorch optimizer; confirm against the optimizer being constructed.

    Args:
        model: The model whose parameters should be checked

    Returns:
        True if all parameters live on a ``cpu`` or ``cuda`` device and have a
        floating-point dtype. A model without any parameters also returns
        True.
    """
    return all(p.device.type in ("cpu", "cuda") and p.is_floating_point() for p in model.parameters())