Source code for ml.optimizers.gan

"""GAN model optimizer wrapper.

This wrapper allows downstream users to set different optimizers or optimizer
parameters for the generator and discriminator of a GAN.

This class is used by the GAN trainer interface and shouldn't be used elsewhere.
"""

from dataclasses import dataclass
from typing import Any

from omegaconf import MISSING, DictConfig
from torch import nn
from torch.optim.optimizer import Optimizer

from ml.core.config import conf_field
from ml.core.registry import register_optimizer
from ml.optimizers.base import BaseOptimizer, BaseOptimizerConfig


@dataclass
class GenerativeAdversarialNetworkOptimizerConfig(BaseOptimizerConfig):
    """Configuration for the GAN optimizer wrapper.

    Holds one nested optimizer configuration for the generator and one for the
    discriminator. Each nested configuration must carry a ``name`` key, which
    is used to look up the matching entry in ``register_optimizer``.
    """

    generator: Any = conf_field(MISSING, help="The generator optimizer to use")
    discriminator: Any = conf_field(MISSING, help="The discriminator optimizer to use")

    @classmethod
    def update(cls: type["GenerativeAdversarialNetworkOptimizerConfig"], config: DictConfig) -> DictConfig:
        """Updates the config, including both nested optimizer configs.

        After the parent class's ``update`` has run, each nested config is
        passed through the ``update`` method of the config class registered
        under its ``name``.

        Args:
            config: The raw config for this optimizer.

        Returns:
            The config with the ``generator`` and ``discriminator`` entries
            replaced by their updated versions.

        Raises:
            ValueError: If either nested config has no ``name`` entry.
        """
        config = super().update(config)

        # Validate with explicit raises rather than ``assert``. Assertions are
        # stripped under ``python -O``, which would also skip the name
        # bindings and turn a missing name into a confusing ``NameError``.
        gen_name = config.generator.get("name")
        if gen_name is None:
            raise ValueError("The generator name must be specified")
        dis_name = config.discriminator.get("name")
        if dis_name is None:
            raise ValueError("The discriminator name must be specified")

        # The registry lookup returns a pair; only the second element (the
        # config class) is needed here.
        _, gen_cfg_cls = register_optimizer.lookup(gen_name)
        config.generator = gen_cfg_cls.update(config.generator)
        _, dis_cfg_cls = register_optimizer.lookup(dis_name)
        config.discriminator = dis_cfg_cls.update(config.discriminator)
        return config

    @classmethod
    def resolve(
        cls: type["GenerativeAdversarialNetworkOptimizerConfig"],
        config: "GenerativeAdversarialNetworkOptimizerConfig",
    ) -> None:
        """Resolves both nested optimizer configs in place.

        Each nested config is handed to the ``resolve`` method of the config
        class registered under its ``name``.

        Args:
            config: The config whose ``generator`` and ``discriminator``
                entries should be resolved.
        """
        _, gen_cfg_cls = register_optimizer.lookup(config.generator.name)
        gen_cfg_cls.resolve(config.generator)
        _, dis_cfg_cls = register_optimizer.lookup(config.discriminator.name)
        dis_cfg_cls.resolve(config.discriminator)
@register_optimizer("gan", GenerativeAdversarialNetworkOptimizerConfig)
class GenerativeAdversarialNetworkOptimizer(BaseOptimizer[GenerativeAdversarialNetworkOptimizerConfig, Optimizer]):
    """Container holding separate generator and discriminator optimizers.

    The two wrapped optimizer objects are exposed as the ``generator`` and
    ``discriminator`` attributes; this class does not build a single
    optimizer for a model itself.
    """

    def __init__(self, config: GenerativeAdversarialNetworkOptimizerConfig) -> None:
        """Builds the two wrapped optimizers from their nested configs.

        Args:
            config: The GAN optimizer config; its ``generator`` and
                ``discriminator`` entries each name a registered optimizer.
        """
        super().__init__(config)

        # The registry lookup returns a pair; the first element is the class
        # that gets instantiated with the nested config.
        generator_cfg = config.generator
        generator_cls, _ = register_optimizer.lookup(generator_cfg.name)
        self.generator = generator_cls(generator_cfg)

        discriminator_cfg = config.discriminator
        discriminator_cls, _ = register_optimizer.lookup(discriminator_cfg.name)
        self.discriminator = discriminator_cls(discriminator_cfg)

    def get(self, model: nn.Module) -> Optimizer:
        """Always raises, since this wrapper is not meant to be called directly.

        Args:
            model: Unused.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("This method shouldn't be called directly.")