# Source code for ml.tasks.gan.base

"""Defines the base GAN task type.

This class expects you to implement the following functions:

.. code-block:: python

    class MyGanTask(
        ml.GenerativeAdversarialNetworkTask[
            Config,
            Generator,
            Discriminator,
            Batch,
            GeneratorOutput,
            DiscriminatorOutput,
            Loss,
        ],
    ):
        def run_generator(self, generator: Generator, batch: Batch, state: ml.State) -> GeneratorOutput:
            ...

        def run_discriminator(
            self,
            discriminator: Discriminator,
            batch: Batch,
            gen_output: GeneratorOutput,
            state: ml.State,
        ) -> DiscriminatorOutput:
            ...

        def compute_discriminator_loss(
            self,
            generator: Generator,
            discriminator: Discriminator,
            batch: Batch,
            state: ml.State,
            gen_output: GeneratorOutput,
            dis_output: DiscriminatorOutput,
        ) -> Loss:
            ...

        def get_dataset(self, phase: ml.Phase) -> Dataset:
            ...
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

from torch import Tensor

from ml.core.common_types import Batch
from ml.core.state import State
from ml.models.gan import DiscriminatorT, GenerativeAdversarialNetworkModel, GeneratorT
from ml.tasks.sl.base import SupervisedLearningTask, SupervisedLearningTaskConfig

# Module-level logger, namespaced to this module per the standard convention.
logger: logging.Logger = logging.getLogger(__name__)

# Type variable for the task-defined value produced by ``run_generator``.
GeneratorOutput = TypeVar("GeneratorOutput")
# Type variable for the task-defined value produced by ``run_discriminator``.
DiscriminatorOutput = TypeVar("DiscriminatorOutput")


@dataclass
class GenerativeAdversarialNetworkTaskConfig(SupervisedLearningTaskConfig):
    """Configuration for GAN tasks.

    Currently adds no fields beyond :class:`SupervisedLearningTaskConfig`;
    it exists so GAN tasks have their own config type to extend later.
    """
# Type variable bound to the GAN task config, so subclasses can narrow it.
GenerativeAdversarialNetworkTaskConfigT = TypeVar(
    "GenerativeAdversarialNetworkTaskConfigT",
    bound=GenerativeAdversarialNetworkTaskConfig,
)
class GenerativeAdversarialNetworkTask(
    SupervisedLearningTask[
        GenerativeAdversarialNetworkTaskConfigT,
        GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT],
        Batch,
        tuple[GeneratorOutput, DiscriminatorOutput],
        dict[str, Tensor],
    ],
    Generic[
        GenerativeAdversarialNetworkTaskConfigT,
        GeneratorT,
        DiscriminatorT,
        Batch,
        GeneratorOutput,
        DiscriminatorOutput,
    ],
    ABC,
):
    """Base class for generative adversarial network tasks.

    Subclasses implement :meth:`run_generator`, :meth:`run_discriminator`
    and :meth:`compute_discriminator_loss`; the generator loss defaults to
    the negated discriminator loss (see :meth:`compute_generator_loss`).
    Loss keys are namespaced with ``gen/`` and ``dis/`` prefixes by
    :meth:`compute_loss`.
    """

    @abstractmethod
    def run_generator(self, generator: GeneratorT, batch: Batch, state: State) -> GeneratorOutput:
        """Runs the generator model on the given batch.

        Args:
            generator: The generator module.
            batch: The batch to run the model on.
            state: The current training state.

        Returns:
            The output of the generator model
        """

    @abstractmethod
    def run_discriminator(
        self,
        discriminator: DiscriminatorT,
        batch: Batch,
        gen_output: GeneratorOutput,
        state: State,
    ) -> DiscriminatorOutput:
        """Runs the discriminator model on the given batch.

        Args:
            discriminator: The discriminator model.
            batch: The batch to run the model on.
            gen_output: The output of the generator model.
            state: The current training state.

        Returns:
            The output of the discriminator model
        """

    @abstractmethod
    def compute_discriminator_loss(
        self,
        generator: GeneratorT,
        discriminator: DiscriminatorT,
        batch: Batch,
        state: State,
        gen_output: GeneratorOutput,
        dis_output: DiscriminatorOutput,
    ) -> dict[str, Tensor]:
        """Computes the discriminator loss for the given batch.

        Args:
            generator: The generator model.
            discriminator: The discriminator model.
            batch: The batch to run the model on.
            state: The current training state.
            gen_output: The output of the generator model.
            dis_output: The output of the discriminator model.

        Returns:
            The discriminator loss.
        """

    def compute_generator_loss(
        self,
        generator: GeneratorT,
        discriminator: DiscriminatorT,
        batch: Batch,
        state: State,
        gen_output: GeneratorOutput,
        dis_output: DiscriminatorOutput,
    ) -> dict[str, Tensor]:
        """Computes the generator loss as the negated discriminator loss.

        Override this for objectives where the generator loss is not simply
        the negation of the discriminator loss (e.g. non-saturating GAN loss).

        Args:
            generator: The generator model.
            discriminator: The discriminator model.
            batch: The batch to run the model on.
            state: The current training state.
            gen_output: The output of the generator model.
            dis_output: The output of the discriminator model.

        Returns:
            The generator loss, with each value negated.
        """
        loss = self.compute_discriminator_loss(generator, discriminator, batch, state, gen_output, dis_output)
        return {k: -v for k, v in loss.items()}

    def do_logging(
        self,
        generator: GeneratorT,
        discriminator: DiscriminatorT,
        batch: Batch,
        state: State,
        gen_output: GeneratorOutput,
        dis_output: DiscriminatorOutput,
        losses: dict[str, Tensor],
    ) -> None:
        """Override this method to perform any logging.

        This will avoid some annoying context manager issues.
        """

    def run_model(
        self,
        model: GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT],
        batch: Batch,
        state: State,
    ) -> tuple[GeneratorOutput, DiscriminatorOutput]:
        """Runs the generator, then feeds its output to the discriminator.

        Args:
            model: The combined GAN model holding both sub-modules.
            batch: The batch to run the model on.
            state: The current training state.

        Returns:
            A tuple of (generator output, discriminator output).
        """
        gen_model, dis_model = model.generator, model.discriminator
        generator_output = self.run_generator(gen_model, batch, state)
        discriminator_output = self.run_discriminator(dis_model, batch, generator_output, state)
        return generator_output, discriminator_output

    def compute_loss(
        self,
        model: GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT],
        batch: Batch,
        state: State,
        output: tuple[GeneratorOutput, DiscriminatorOutput],
    ) -> dict[str, Tensor]:
        """Computes generator and discriminator losses for one step.

        Keys from the generator loss are prefixed with ``gen/`` and keys
        from the discriminator loss with ``dis/``, then merged into a
        single dictionary. :meth:`do_logging` is called with the merged
        losses before returning.

        Args:
            model: The combined GAN model.
            batch: The batch to run the model on.
            state: The current training state.
            output: The (generator output, discriminator output) pair from
                :meth:`run_model`.

        Returns:
            The merged, prefix-namespaced loss dictionary.
        """
        gen_model, dis_model = model.generator, model.discriminator
        gen_output, dis_output = output
        gen_losses = self.compute_generator_loss(gen_model, dis_model, batch, state, gen_output, dis_output)
        dis_losses = self.compute_discriminator_loss(gen_model, dis_model, batch, state, gen_output, dis_output)
        losses = {**{f"gen/{k}": v for k, v in gen_losses.items()}, **{f"dis/{k}": v for k, v in dis_losses.items()}}
        self.do_logging(gen_model, dis_model, batch, state, gen_output, dis_output, losses)
        return losses

    def separate_losses(self, losses: dict[str, Tensor]) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
        """Splits a merged loss dictionary into generator and discriminator parts.

        Keys are matched on their ``gen/`` / ``dis/`` prefix (the prefix is
        kept on the returned keys).

        Args:
            losses: A loss dictionary as produced by :meth:`compute_loss`.

        Returns:
            A tuple of (generator losses, discriminator losses).

        Raises:
            ValueError: If a key has neither the ``gen/`` nor ``dis/`` prefix.
        """
        gen_losses, dis_losses = {}, {}
        for k, v in losses.items():
            if k.startswith("gen/"):
                gen_losses[k] = v
            elif k.startswith("dis/"):
                dis_losses[k] = v
            else:
                raise ValueError(f"Invalid loss key: {k}")
        return gen_losses, dis_losses

    # -----
    # Hooks
    # -----

    def on_after_gan_forward_step(
        self,
        generator: GeneratorT,
        discriminator: DiscriminatorT,
        batch: Batch,
        state: State,
        gen_output: GeneratorOutput,
        dis_output: DiscriminatorOutput,
    ) -> None:
        """GAN-specific hook that is called after a forward step.

        This is useful for implementing the Wasserstein GAN gradient penalty.

        Args:
            generator: The generator model.
            discriminator: The discriminator model.
            batch: The batch to run the model on.
            state: The current training state.
            gen_output: The output of the generator model.
            dis_output: The output of the discriminator model.
        """

    def on_after_forward_step(
        self,
        model: GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT],
        batch: Batch,
        output: tuple[GeneratorOutput, DiscriminatorOutput],
        state: State,
    ) -> None:
        # Forward the generic hook to the GAN-specific hook with the model
        # unpacked into its generator/discriminator halves.
        super().on_after_forward_step(model, batch, output, state)
        self.on_after_gan_forward_step(model.generator, model.discriminator, batch, state, output[0], output[1])