# Source code for ml.optimizers.base

"""Defines the base optimizer adapter.

The optimizer adapter classes defined here usually just wrap PyTorch
optimizers, providing some common hyperparameter configurations.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

from torch import nn
from torch.optim.optimizer import Optimizer

from ml.core.config import BaseConfig, BaseObject

# Type of the concrete PyTorch optimizer that an adapter's ``get`` returns.
OptimizerT = TypeVar("OptimizerT", bound=Optimizer)


@dataclass
class BaseOptimizerConfig(BaseConfig):
    """Defines the base config for all optimizers."""


# Type of the config dataclass an optimizer adapter is parameterized by;
# restricted to subclasses of ``BaseOptimizerConfig``.
OptimizerConfigT = TypeVar("OptimizerConfigT", bound=BaseOptimizerConfig)
class BaseOptimizer(BaseObject[OptimizerConfigT], Generic[OptimizerConfigT, OptimizerT], ABC):
    """Defines the base optimizer type.

    Generic over the config type (``OptimizerConfigT``) and the PyTorch
    optimizer type it produces (``OptimizerT``). Concrete subclasses must
    implement :meth:`get`.
    """

    @property
    def common_kwargs(self) -> dict[str, Any]:
        """Returns a dict of keyword arguments for the optimizer.

        The base implementation returns a new empty dict on every access.

        NOTE(review): presumably subclasses override this to supply keyword
        arguments shared across optimizers; confirm against subclasses.
        """
        return {}

    @abstractmethod
    def get(self, model: nn.Module) -> OptimizerT:
        """Given a base module, returns an optimizer.

        Args:
            model: The model to get an optimizer for.

        Returns:
            The constructed optimizer.
        """