Source code for ml.optimizers.adam

"""Wrapper around the PyTorch Adam / AdamW optimizer.

With weight decay greater than 0 (which is the default), uses the AdamW variant
of the optimizer.
"""

from dataclasses import dataclass

from torch import nn
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW

from ml.core.config import conf_field
from ml.core.registry import register_optimizer
from ml.optimizers.base import BaseOptimizer, BaseOptimizerConfig
from ml.optimizers.common import can_use_foreach, can_use_fused, separate_decayable_params


[docs]@dataclass class AdamOptimizerConfig(BaseOptimizerConfig): lr: float = conf_field(3e-4, help="Learning rate") betas: tuple[float, float] = conf_field((0.9, 0.999), help="Beta coefficients") eps: float = conf_field(1e-5, help="Epsilon term to add to the denominator for stability") weight_decay: float = conf_field(1e-5, help="Weight decay regularization to use") amsgrad: bool = conf_field(False, help="Whether to use the AMSGrad variant of the algorithm") default_decay: bool = conf_field(True, help="Whether to decay module params which aren't explicitly specified") foreach: bool | None = conf_field(None, help="Whether to use the foreach variant of the optimizer") capturable: bool = conf_field(False, help="Whether to use capturable AdamW pathway") differentiable: bool = conf_field(False, help="Whether to use differentiable AdamW") fused: bool | None = conf_field(None, help="Whether to use the fused optimizer")
[docs] @classmethod def get_defaults(cls) -> dict[str, "AdamOptimizerConfig"]: return { "gpt-3-small": AdamOptimizerConfig( lr=6e-4, betas=(0.9, 0.95), weight_decay=0.1, ), "gpt-3-medium": AdamOptimizerConfig( lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1, ), "gpt-3-large": AdamOptimizerConfig( lr=2.5e-4, betas=(0.9, 0.95), weight_decay=0.1, ), "roberta-base": AdamOptimizerConfig( lr=6e-4, betas=(0.9, 0.98), weight_decay=0.01, ), "roberta-large": AdamOptimizerConfig( lr=4e-4, betas=(0.9, 0.98), weight_decay=0.01, ), }
[docs]@register_optimizer("adam", AdamOptimizerConfig) class AdamOptimizer(BaseOptimizer[AdamOptimizerConfig, Adam | AdamW]):
[docs] def get(self, model: nn.Module) -> Adam | AdamW: b1, b2 = self.config.betas # Chooses reasonable defaults for foreach and fused variants. fused, foreach = self.config.fused, self.config.foreach if foreach is not None and fused is not None: assert not (foreach and fused), "Cannot use both foreach and fused variants of Adam" if foreach is None and fused is None: if not self.config.differentiable and can_use_fused(model): fused = True elif can_use_foreach(model): foreach = True if fused is None: fused = False if foreach is None: foreach = False if self.config.weight_decay > 0.0: return AdamW( separate_decayable_params(model, self.config.default_decay, self.config.weight_decay), lr=self.config.lr, betas=(b1, b2), eps=self.config.eps, amsgrad=self.config.amsgrad, foreach=foreach, capturable=self.config.capturable, differentiable=self.config.differentiable, fused=fused, **self.common_kwargs, ) return Adam( model.parameters(), lr=self.config.lr, betas=(b1, b2), eps=self.config.eps, amsgrad=self.config.amsgrad, foreach=foreach, capturable=self.config.capturable, differentiable=self.config.differentiable, fused=fused, **self.common_kwargs, )