ml.optimizers.shampoo

Wrapper around the Shampoo optimizer.

This optimizer was proposed in Shampoo: Preconditioned Stochastic Tensor Optimization.

class ml.optimizers.shampoo.Shampoo(params: Iterable[Tensor] | Iterable[dict[str, Any]], lr: float = 0.1, momentum: float = 0.0, weight_decay: float = 0.0, epsilon: float = 0.0001, update_freq: int = 1)[source]

Bases: Optimizer

Implements the Shampoo optimizer algorithm.

This is taken from the pytorch-optimizer package.

import torch_optimizer as optim

# model, loss_fn, input, and target are assumed to be defined elsewhere.
optimizer = optim.Shampoo(model.parameters(), lr=0.01)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

It has been proposed in Shampoo: Preconditioned Stochastic Tensor Optimization.

Note

This is not an implementation of the later paper, Scalable Second Order Optimization for Deep Learning, which is becoming more popular.

Parameters:
  • params – iterable of parameters to optimize or dicts defining parameter groups

  • lr – learning rate (default: 0.1)

  • momentum – momentum factor (default: 0)

  • weight_decay – weight decay (L2 penalty) (default: 0)

  • epsilon – epsilon added to each mat_gbar_j for numerical stability (default: 1e-4)

  • update_freq – update frequency to compute inverse (default: 1)
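
As a sketch of how these parameters fit together (assuming a small torch model; the hyperparameter values below are illustrative, not recommendations), note that computing the inverse preconditioner roots is the expensive part of Shampoo, so update_freq controls how often that cost is paid:

import torch.nn as nn
import torch_optimizer as optim

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

# A larger update_freq amortizes the cost of recomputing the inverse
# preconditioners over several steps, at the price of a slightly staler
# preconditioner.
optimizer = optim.Shampoo(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    epsilon=1e-4,
    update_freq=10,
)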

step(closure: Callable[[], float] | None = None) → float | None[source]

Performs a single optimization step.

Parameters:

closure – A closure that reevaluates the model and returns the loss.

Returns:

The loss returned by the closure, if one is provided; otherwise None.
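
A minimal sketch of the closure form, assuming model, loss_fn, input, and target are defined as in the usage example above:

# The closure recomputes the loss and gradients so the optimizer can
# invoke it internally during the step.
def closure():
    optimizer.zero_grad()
    loss = loss_fn(model(input), target)
    loss.backward()
    return loss

loss = optimizer.step(closure)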

class ml.optimizers.shampoo.ShampooOptimizerConfig(name: str = '???', lr: float = 0.001, momentum: float = 0.0, weight_decay: float = 0.0, epsilon: float = 0.0001, update_freq: int = 1, default_decay: bool = True)[source]

Bases: BaseOptimizerConfig

lr: float = 0.001
momentum: float = 0.0
weight_decay: float = 0.0
epsilon: float = 0.0001
update_freq: int = 1
default_decay: bool = True
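
As a sketch of constructing the config directly from the fields above (the name value is an illustrative placeholder; '???' in the signature marks the field as required):

config = ShampooOptimizerConfig(
    name="shampoo",  # illustrative; '???' means this field must be supplied
    lr=1e-3,
    momentum=0.9,
    update_freq=10,
)
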
class ml.optimizers.shampoo.ShampooOptimizer(config: BaseConfigT)[source]

Bases: BaseOptimizer[ShampooOptimizerConfig, Shampoo]

get(model: Module) → Shampoo[source]

Given a base module, returns an optimizer.

Parameters:

model – The model to get an optimizer for

Returns:

The constructed optimizer
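
A hedged sketch of tying the pieces together, inferred from the signatures above (the surrounding framework may normally construct these via its own registry rather than by hand):

import torch.nn as nn

model = nn.Linear(16, 4)
config = ShampooOptimizerConfig(name="shampoo", lr=1e-3)
optimizer = ShampooOptimizer(config).get(model)  # returns a Shampoo instance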