Source code for ml.optimizers.lion

"""Wrapper around the Lion optimizer.

This optimizer was proposed in `Symbolic Discovery of Optimization Algorithms
<https://arxiv.org/abs/2302.06675>`_.

Lion stands for "EvoLved Sign Momentum" (yes, actually). It is more
memory-efficient than Adam since it only keeps track of the momentum (a single
extra buffer per parameter, versus Adam's two).

In the original paper, the authors suggest using a larger batch size and a
learning rate roughly 3-10x smaller than for AdamW, with a correspondingly
larger weight decay.

This optimizer shines for tasks like contrastive learning and diffusion, which
optimize proxy objectives rather than something like a cross-entropy
classification loss, although the authors show in the paper that it performs
comparably to Adam on language modeling.

This implementation is based on the ``lucidrains`` implementation
`here <https://github.com/lucidrains/lion-pytorch/>`_ and on the pseudo-code
from the paper, which is reproduced below:

.. highlight:: python
.. code-block:: python

    def train(weight, gradient, momentum, lr):
        # interp(a, b, t) is linear interpolation: (1 - t) * a + t * b
        update = interp(gradient, momentum, beta1)
        update = sign(update)
        momentum = interp(gradient, momentum, beta2)
        update = update + weight * weight_decay
        update = update * lr
        return update, momentum
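
A minimal usage sketch (``Lion`` is the optimizer class defined below and
behaves like any other ``torch.optim``-style optimizer; the tiny ``nn.Linear``
model is just a stand-in):

.. code-block:: python

    model = nn.Linear(8, 2)
    opt = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()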
"""

import logging
from dataclasses import dataclass
from typing import Callable, cast

import torch
from torch import Tensor, nn
from torch.optim.optimizer import Optimizer

from ml.core.config import conf_field
from ml.core.registry import register_optimizer
from ml.optimizers.base import BaseOptimizer, BaseOptimizerConfig
from ml.optimizers.common import separate_decayable_params
from ml.optimizers.types import Params
from ml.utils.triton import supports_triton

logger = logging.getLogger(__name__)


def _update_fn_vanilla(
    p: nn.Parameter,
    grad: Tensor,
    exp_avg: Tensor,
    lr: float,
    wd: float,
    beta1: float,
    beta2: float,
) -> None:
    """Runs the update function for a given parameter.

    This can be made slightly faster using Triton, if GPU acceleration is
    available. Make sure Triton is installed and set ``use_triton=True`` in
    the optimizer configuration.

    Args:
        p: Parameter to update.
        grad: Gradient for the parameter.
        exp_avg: Exponential average of the gradient.
        lr: Learning rate.
        wd: Weight decay.
        beta1: First momentum coefficient.
        beta2: Second momentum coefficient.
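
    A rough sketch of calling this helper directly on dummy tensors (normally
    it is only invoked internally via ``get_update_fn``)::

        p = nn.Parameter(torch.zeros(3))
        p.grad = torch.randn(3)  # stand-in gradient, for illustration only
        exp_avg = torch.zeros(3)
        _update_fn_vanilla(p, p.grad, exp_avg, lr=1e-4, wd=1e-2, beta1=0.9, beta2=0.99)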
    """
    # Sign of the interpolation between the momentum and the current gradient.
    update = exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_()
    # Decoupled weight decay, followed by the parameter update.
    p.data.mul_(1 - lr * wd).add_(update, alpha=-lr)
    # Update the momentum (exponential moving average of the gradient) in place.
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)


def get_update_fn(cpu: bool) -> Callable[[nn.Parameter, Tensor, Tensor, float, float, float, float], None]:
    if cpu or not supports_triton():
        return _update_fn_vanilla

    from ml.utils.triton.lion import update_fn as triton_update_fn

    return triton_update_fn


class Lion(Optimizer):
    def __init__(
        self,
        params: Params,
        lr: float = 1e-4,
        betas: tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        use_triton: bool = False,
    ) -> None:
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not all([0.0 <= beta <= 1.0 for beta in betas]):
            raise ValueError(f"Invalid beta: {betas}")

        defaults = {
            "lr": lr,
            "betas": betas,
            "weight_decay": weight_decay,
        }

        super().__init__(params, defaults)

        self.update_fn = get_update_fn(True)
        self.update_fn_cuda = get_update_fn(False)

    @torch.no_grad()
    def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                p = cast(Tensor, p)
                if p.grad is None:
                    continue

                grad = p.grad.data
                lr: float = group["lr"]
                wd: float = group["weight_decay"]
                beta1, beta2 = group["betas"]
                state = self.state[p]

                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p)

                update_fn = self.update_fn_cuda if grad.is_cuda else self.update_fn
                update_fn(p, grad, state["exp_avg"], lr, wd, beta1, beta2)

        return loss


@dataclass
class LionOptimizerConfig(BaseOptimizerConfig):
    lr: float = conf_field(1e-4, help="Learning rate.")
    betas: tuple[float, float] = conf_field((0.9, 0.99), help="Beta coefficients.")
    weight_decay: float = conf_field(1e-2, help="Weight decay.")
    default_decay: bool = conf_field(True, help="Whether to decay module params which aren't explicitly specified")
    use_triton: bool = conf_field(True, help="Whether to use Triton for faster updates.")

    @classmethod
    def get_defaults(cls) -> dict[str, "LionOptimizerConfig"]:
        return {
            "lion-stable": LionOptimizerConfig(
                betas=(0.95, 0.98),
            ),
        }


@register_optimizer("lion", LionOptimizerConfig)
class LionOptimizer(BaseOptimizer[LionOptimizerConfig, Lion]):
    def get(self, model: nn.Module) -> Lion:
        return Lion(
            separate_decayable_params(model, self.config.default_decay, self.config.weight_decay),
            lr=self.config.lr,
            betas=self.config.betas,
            use_triton=self.config.use_triton,
        )
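

if __name__ == "__main__":
    # Rough usage sketch, not part of the library proper. Assumption: a
    # ``BaseOptimizer`` subclass can be constructed directly from its config
    # dataclass (which is how ``self.config`` above gets populated); in normal
    # use the registry presumably builds this from the experiment config.
    demo_model = nn.Linear(8, 2)
    demo_optimizer = LionOptimizer(LionOptimizerConfig(lr=1e-4)).get(demo_model)
    demo_model(torch.randn(4, 8)).sum().backward()
    demo_optimizer.step()
    demo_optimizer.zero_grad()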