ml.optimizers.adam

Wrapper around the PyTorch Adam / AdamW optimizer.

When the weight decay is greater than 0 (which it is by default), the AdamW variant of the optimizer is used; otherwise plain Adam is used.
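As an illustration only (this is a sketch of the behavior described above, not the library source; the helper name build_adam is hypothetical), the Adam / AdamW choice based on the configured weight decay looks roughly like this:

from torch.optim import Adam, AdamW

def build_adam(params, lr=3e-4, betas=(0.9, 0.999), eps=1e-5, weight_decay=1e-5, amsgrad=False):
    # Any positive weight decay selects the decoupled-decay AdamW variant.
    if weight_decay > 0.0:
        return AdamW(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
    # Otherwise fall back to plain Adam (no weight decay applied).
    return Adam(params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad)

The actual construction is done by AdamOptimizer.get, documented below.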

class ml.optimizers.adam.AdamOptimizerConfig(name: str = '???', lr: float = 0.0003, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-05, weight_decay: float = 1e-05, amsgrad: bool = False, default_decay: bool = True, foreach: bool | None = None, capturable: bool = False, differentiable: bool = False, fused: bool | None = None)[source]

Bases: BaseOptimizerConfig

lr: float = 0.0003
betas: tuple[float, float] = (0.9, 0.999)
eps: float = 1e-05
weight_decay: float = 1e-05
amsgrad: bool = False
default_decay: bool = True
foreach: bool | None = None
capturable: bool = False
differentiable: bool = False
fused: bool | None = None
classmethod get_defaults() → dict[str, 'AdamOptimizerConfig'][source]

Returns default configurations.

Returns:

A dictionary of default configurations for this config class
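A hypothetical usage sketch (the preset name "adam" and the field values are assumptions for illustration, not values guaranteed by the library): constructing a config directly and inspecting the named defaults returned by get_defaults().

from ml.optimizers.adam import AdamOptimizerConfig

# Construct a config explicitly; unspecified fields keep their documented defaults.
config = AdamOptimizerConfig(name="adam", lr=1e-4, weight_decay=0.0)

# get_defaults() returns a dict[str, AdamOptimizerConfig]; the keys are
# whatever preset names the library defines.
for key, preset in AdamOptimizerConfig.get_defaults().items():
    print(key, preset.lr, preset.weight_decay)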

class ml.optimizers.adam.AdamOptimizer(config: BaseConfigT)[source]

Bases: BaseOptimizer[AdamOptimizerConfig, Adam | AdamW]

get(model: Module) → Adam | AdamW[source]

Given a base module, returns an optimizer.

Parameters:

model – The model to get an optimizer for

Returns:

The constructed optimizer
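A hypothetical end-to-end sketch (the model and config values below are placeholders): building an optimizer for a module through the wrapper.

import torch.nn as nn
from ml.optimizers.adam import AdamOptimizer, AdamOptimizerConfig

model = nn.Linear(16, 4)
config = AdamOptimizerConfig(name="adam", lr=3e-4, weight_decay=1e-5)
optimizer = AdamOptimizer(config).get(model)  # AdamW here, since weight_decay > 0

The returned optimizer is a standard torch.optim.Adam or torch.optim.AdamW instance and can be used in an ordinary PyTorch training loop (zero_grad / backward / step).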