Source code for ml.optimizers.adan

"""Wrapper around the PyTorch Adan optimizer."""

from dataclasses import dataclass
from typing import Any, Callable, Iterable

import torch
from torch import nn
from torch.optim import Optimizer

from ml.core.config import conf_field
from ml.core.registry import register_optimizer
from ml.optimizers.base import BaseOptimizer, BaseOptimizerConfig
from ml.optimizers.common import separate_decayable_params


class Adan(Optimizer):
    def __init__(
        self,
        params: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
        lr: float = 1e-3,
        betas: tuple[float, float, float] = (0.1, 0.1, 0.001),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
    ) -> None:
        assert len(betas) == 3
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}
        super().__init__(params, defaults)
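    # NOTE: Each beta weights the *incoming* term of its moving average,
    # i.e. ema = (1 - beta) * ema + beta * new_value. Conventions in which
    # beta weights the previous state would write the defaults
    # (0.1, 0.1, 0.001) as (0.9, 0.9, 0.999).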
    @torch.no_grad()
    def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
        loss: float | None = None
        if closure is not None:
            # Re-enable gradients so the closure can run a backward pass.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2, beta3 = group["betas"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                data, grad = p.data, p.grad.data
                assert not grad.is_sparse

                state = self.state[p]

                # Lazily initialize the state buffers: `m` is the EMA of the
                # gradient, `v` the EMA of gradient differences, and `n` the
                # EMA used for the adaptive step size.
                if len(state) == 0:
                    state["step"] = 0
                    state["m"] = grad.clone()
                    state["v"] = torch.zeros_like(grad)
                    state["n"] = grad**2

                step, m, v, n = state["step"], state["m"], state["v"], state["n"]
                zeroth_step = step == 0
                first_step = step == 1

                # The gradient difference is only defined once a previous
                # gradient has been stored, so the EMAs start updating at
                # step 1 (and `v` only starts decaying at step 2).
                if not zeroth_step:
                    prev_grad = state["prev_grad"]
                    m.mul_(1 - beta1).add_(grad, alpha=beta1)
                    grad_diff = grad - prev_grad
                    if not first_step:
                        v.mul_(1 - beta2).add_(grad_diff, alpha=beta2)
                    else:
                        v.add_(grad_diff)
                    next_n = (grad + (1 - beta2) * grad_diff) ** 2
                    n.mul_(1 - beta3).add_(next_n, alpha=beta3)

                # Parameter update, with decoupled weight decay applied as a
                # division by (1 + weight_decay * lr).
                weighted_step_size = lr / (n + eps).sqrt()
                denom = 1 + weight_decay * lr
                data.addcmul_(weighted_step_size, (m + (1 - beta2) * v), value=-1.0).div_(denom)

                state["prev_grad"] = grad.clone()
                state["step"] += 1

        return loss
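For each parameter theta with gradient g_k at step k (once the buffers are warmed up), ``step`` implements the update

\begin{aligned}
m_k &= (1 - \beta_1)\, m_{k-1} + \beta_1\, g_k \\
v_k &= (1 - \beta_2)\, v_{k-1} + \beta_2\, (g_k - g_{k-1}) \\
n_k &= (1 - \beta_3)\, n_{k-1} + \beta_3\, \bigl(g_k + (1 - \beta_2)(g_k - g_{k-1})\bigr)^2 \\
\theta_k &= \frac{\theta_{k-1} - \frac{\mathrm{lr}}{\sqrt{n_k + \epsilon}}\,\bigl(m_k + (1 - \beta_2)\, v_k\bigr)}{1 + \lambda \cdot \mathrm{lr}}
\end{aligned}

where \lambda is ``weight_decay``. Note that this code adds \epsilon inside the square root, i.e. \sqrt{n_k + \epsilon}, rather than the \sqrt{n_k} + \epsilon form some references use.

A minimal usage sketch (the model, data, and loss here are placeholders, not part of this module):

    import torch
    import torch.nn.functional as F
    from torch import nn

    model = nn.Linear(10, 2)
    optimizer = Adan(model.parameters(), lr=1e-3)

    inputs, targets = torch.randn(8, 10), torch.randn(8, 2)
    for _ in range(10):
        optimizer.zero_grad()
        loss = F.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()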
@dataclass
class AdanOptimizerConfig(BaseOptimizerConfig):
    lr: float = conf_field(1e-3, help="Learning rate")
    betas: tuple[float, float, float] = conf_field((0.1, 0.1, 0.001), help="Beta coefficients")
    eps: float = conf_field(1e-4, help="Epsilon term")
    weight_decay: float = conf_field(1e-5, help="Weight decay regularization to use")
    default_decay: bool = conf_field(True, help="Whether to decay module params which aren't explicitly specified")
@register_optimizer("adan", AdanOptimizerConfig)
class AdanOptimizer(BaseOptimizer[AdanOptimizerConfig, Adan]):
    def get(self, model: nn.Module) -> Adan:
        b1, b2, b3 = self.config.betas
        return Adan(
            separate_decayable_params(model, self.config.default_decay, self.config.weight_decay),
            lr=self.config.lr,
            betas=(b1, b2, b3),
            eps=self.config.eps,
            **self.common_kwargs,
        )
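``separate_decayable_params`` builds the parameter groups handed to ``Adan``. As an illustration only, and not the actual helper from ``ml.optimizers.common``, the common pattern is to route biases and other one-dimensional parameters (e.g. normalization scales) into a group with zero weight decay:

    from torch import nn

    def split_params_sketch(model: nn.Module, weight_decay: float) -> list[dict]:
        """Hypothetical stand-in for separate_decayable_params."""
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # Biases and norm weights are typically exempt from weight decay.
            if param.ndim <= 1 or name.endswith(".bias"):
                no_decay.append(param)
            else:
                decay.append(param)
        return [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]

    optimizer = Adan(split_params_sketch(model, weight_decay=1e-5), lr=1e-3)

Since ``Adan`` reads ``weight_decay`` from each parameter group, the no-decay group overrides the default for those tensors.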