ml.optimizers.lion
Wrapper around the Lion optimizer.
This optimizer was proposed in Symbolic Discovery of Optimization Algorithms.
Lion stands for “Evolved Sign Momentum” (yes, actually). It is more memory-efficient than Adam since it only keeps track of momentum.
In the original paper, the authors suggest using a larger batch size and a smaller learning rate compared to Adam.
This optimizer shines for tasks like contrastive learning and diffusion, which optimize proxy objectives rather than something like cross-entropy classification, although the paper shows that it performs comparably to Adam on language modeling.
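As a rough illustration of that guidance (the exact factor is an assumption taken from the paper's discussion, not something this module prescribes), one way to pick a starting learning rate is to shrink an existing AdamW value:

    # Hypothetical learning rate previously tuned for AdamW.
    adamw_lr = 3e-4

    # Since sign(update) has unit magnitude per element, Lion's updates tend to be
    # larger, so a roughly 3-10x smaller learning rate is a common starting point.
    lion_lr = adamw_lr / 10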
This implementation is based on lucidrains' implementation and on the pseudo-code from the paper, which is reproduced below:
    def train(weight, gradient, momentum, lr):
        # Interpolate between the gradient and the momentum, then take the sign.
        update = interp(gradient, momentum, beta1)
        update = sign(update)
        # Update the momentum using the second interpolation factor.
        momentum = interp(gradient, momentum, beta2)
        # Apply decoupled weight decay, then scale by the learning rate.
        update = update + weight * weight_decay
        update = update * lr
        return update, momentum
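For concreteness, here is a minimal, runnable sketch of that update using PyTorch tensors. It only illustrates the pseudo-code above; the function name and defaults are made up for the example, and it is not the module's own update kernel.

    import torch

    def lion_update(weight, grad, momentum, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
        # interp(a, b, t) = (1 - t) * a + t * b, which is exactly Tensor.lerp.
        update = grad.lerp(momentum, beta1).sign()
        momentum = grad.lerp(momentum, beta2)
        # Decoupled weight decay is folded in before scaling by the learning rate.
        update = (update + weight * weight_decay) * lr
        return update, momentum

    # The caller applies the update by subtracting it from the weights.
    w, g, m = torch.randn(4), torch.randn(4), torch.zeros(4)
    update, m = lion_update(w, g, m)
    w = w - update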
- ml.optimizers.lion.get_update_fn(cpu: bool) → Callable[[Parameter, Tensor, Tensor, float, float, float, float], None] [source]
- class ml.optimizers.lion.Lion(params: Iterable[Tensor] | Iterable[dict[str, Any]], lr: float = 0.0001, betas: tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.0, use_triton: bool = False)[source]
Bases: Optimizer
- step(closure: Callable[[], float] | None = None) → float | None [source]
Performs a single optimization step (parameter update).
- Parameters:
closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
Note
Unless otherwise specified, this function should not modify the .grad field of the parameters.
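A minimal usage sketch of the Lion class, assuming the standard torch.optim workflow (the model and data below are placeholders):

    import torch
    from torch import nn

    from ml.optimizers.lion import Lion

    model = nn.Linear(16, 4)
    optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0)

    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = nn.functional.mse_loss(model(x), y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # applies the sign-momentum update described above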
- class ml.optimizers.lion.LionOptimizerConfig(name: str = '???', lr: float = 0.0001, betas: tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.01, default_decay: bool = True, use_triton: bool = True)[source]
Bases: BaseOptimizerConfig
- lr: float = 0.0001
- betas: tuple[float, float] = (0.9, 0.99)
- weight_decay: float = 0.01
- default_decay: bool = True
- use_triton: bool = True
- class ml.optimizers.lion.LionOptimizer(config: BaseConfigT)[source]
Bases: BaseOptimizer[LionOptimizerConfig, Lion]
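This page does not show how the wrapper hands the config to the underlying Lion instance, so the following is only a loose sketch; in particular, passing name="lion" is an assumption (the default is the unresolved '???'):

    from ml.optimizers.lion import LionOptimizer, LionOptimizerConfig

    # Assumed registry name; the factory method that builds the Lion instance is not shown here.
    config = LionOptimizerConfig(name="lion", lr=1e-4, weight_decay=0.01, use_triton=False)
    optimizer_wrapper = LionOptimizer(config)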