Source code for ml.optimizers.common

"""Common optimizer utilities."""

from typing import Any, Iterable

from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm


def separate_decayable_params(model: nn.Module, default_decay: bool, weight_decay: float) -> Iterable[dict[str, Any]]:
    """Separates parameters which should be weight decayed from those which should not.

    Biases, normalization weights and embeddings are not decayed; linear,
    convolution and attention weights are. This is mostly taken from nanoGPT.

    Args:
        model: The model to get the parameters for
        default_decay: Whether to decay by default (for modules which aren't explicitly specified)
        weight_decay: The weight decay to use

    Returns:
        The parameter groups to pass to the optimizer
    """
    wd_params: set[str] = set()
    no_wd_params: set[str] = set()
    seen: set[str] = set()

    always_decay = (
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.MultiheadAttention,
    )

    never_decay = (
        _BatchNorm,
        nn.LocalResponseNorm,
        nn.GroupNorm,
        nn.LayerNorm,
        nn.Embedding,
        nn.EmbeddingBag,
    )

    for mn, m in model.named_modules():
        # `recurse=False` so that each parameter is classified by the module which
        # directly owns it, rather than by the root module on the first pass.
        for pn, p in m.named_parameters(recurse=False):
            fpn = f"{mn}.{pn}" if mn else pn
            if fpn in seen:
                continue
            seen.add(fpn)

            if p.ndim < 2:
                no_wd_params.add(fpn)
            elif isinstance(m, never_decay):
                no_wd_params.add(fpn)
            elif isinstance(m, always_decay):
                wd_params.add(fpn)
            else:
                (wd_params if default_decay else no_wd_params).add(fpn)

    param_dict = {pn: p for pn, p in model.named_parameters()}
    inter_params = wd_params & no_wd_params
    union_params = wd_params | no_wd_params
    assert len(inter_params) == 0, "Parameters made it into both decay and no-decay sets!"
    assert len(param_dict.keys() - union_params) == 0, "Parameters were not separated into decay or no-decay set!"

    return [
        {"params": [param_dict[pn] for pn in sorted(list(wd_params))], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(list(no_wd_params))], "weight_decay": 0.0},
    ]
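A minimal usage sketch (not part of the module source): the parameter groups returned by separate_decayable_params can be passed directly to a torch.optim optimizer. The model and hyperparameters below are illustrative placeholders, not values from this library.

import torch
from torch import nn

from ml.optimizers.common import separate_decayable_params

# Illustrative model: the Linear weights should be decayed, while the biases
# and the LayerNorm weight should not.
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))

param_groups = separate_decayable_params(model, default_decay=True, weight_decay=1e-2)
optimizer = torch.optim.AdamW(param_groups, lr=3e-4)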
def can_use_fused(model: nn.Module) -> bool:
    """Checks if the fused optimizer implementation can be used for a model.

    Fused optimizer kernels require all parameters to be floating-point CUDA tensors.
    """
    return all(p.is_cuda and p.is_floating_point() for p in model.parameters())
def can_use_foreach(model: nn.Module) -> bool:
    """Checks if the foreach (multi-tensor) optimizer implementation can be used for a model.

    Foreach optimizer kernels require all parameters to be floating-point tensors on CPU or CUDA.
    """
    return all(p.device.type in ("cpu", "cuda") and p.is_floating_point() for p in model.parameters())
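A hedged sketch of how these predicates might be used together, assuming an illustrative model: they map directly onto the fused and foreach flags accepted by optimizers in recent PyTorch versions.

import torch
from torch import nn

from ml.optimizers.common import can_use_fused, can_use_foreach

# Illustrative model; move it to the GPU if one is available.
model = nn.Linear(8, 8)
if torch.cuda.is_available():
    model = model.cuda()

# Prefer the fused kernel when possible, otherwise fall back to foreach.
use_fused = can_use_fused(model)
use_foreach = can_use_foreach(model) and not use_fused

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, fused=use_fused, foreach=use_foreach)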