ml.tasks.gan.round_robin

Defines a GAN task type that trains the generator and discriminator in a round-robin fashion, alternating between the two according to the configured step counts.

This class expects you to implement the following methods:

class MyGanTask(
    ml.GenerativeAdversarialNetworkRoundRobinTask[
        Config,
        Generator,
        Discriminator,
        Batch,
        GeneratorOutput,
        DiscriminatorOutput,
        Loss,
    ],
):
    def run_generator(self, model: Generator, batch: Batch, state: ml.State) -> GeneratorOutput:
        ...

    def run_discriminator(
        self,
        model: Discriminator,
        batch: Batch,
        gen_output: GeneratorOutput,
        state: ml.State,
    ) -> DiscriminatorOutput:
        ...

    def compute_discriminator_loss(
        self,
        generator: Generator,
        discriminator: Discriminator,
        batch: Batch,
        state: ml.State,
        gen_output: GeneratorOutput,
        dis_output: DiscriminatorOutput,
    ) -> Loss:
        ...

    def get_dataset(self, phase: ml.Phase) -> Dataset:
        ...

class ml.tasks.gan.round_robin.GenerativeAdversarialNetworkRoundRobinTaskConfig(max_epochs: int | None = None, max_steps: int | None = None, max_samples: int | None = None, max_seconds: float | None = None, train_dl: ml.tasks.base.DataLoaderConfig = <factory>, valid_dl: ml.tasks.base.DataLoaderConfig = <factory>, test_dl: ml.tasks.base.DataLoaderConfig = <factory>, name: str = '???', errors: ml.tasks.datasets.error_handling.ErrorHandlingConfig = <factory>, generator_steps: int = 1, discriminator_steps: int = 1)

Bases: SupervisedLearningTaskConfig

generator_steps: int = 1
    Number of consecutive generator updates per round-robin cycle.

discriminator_steps: int = 1
    Number of consecutive discriminator updates per round-robin cycle.
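
Taken together, these two counts define the alternation pattern. A minimal sketch of one plausible reading of that schedule (the real accounting lives in is_generator_step() below; this is illustrative, not the library's code):

# Illustrative only: one plausible reading of the round-robin schedule.
# With generator_steps=1 and discriminator_steps=5 (a common WGAN-style
# split), steps cycle as G, D, D, D, D, D, G, D, ...
def sketch_is_generator_step(step: int, generator_steps: int = 1, discriminator_steps: int = 5) -> bool:
    cycle = generator_steps + discriminator_steps
    return step % cycle < generator_steps

print("".join("G" if sketch_is_generator_step(s) else "D" for s in range(12)))  # GDDDDDGDDDDD
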
class ml.tasks.gan.round_robin.GenerativeAdversarialNetworkRoundRobinTask(config: BaseTaskConfigT)

Bases: SupervisedLearningTask[GenerativeAdversarialNetworkTaskRoundRobinConfigT, GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT], Batch, tuple[GeneratorOutput, DiscriminatorOutput], dict[str, Tensor]], Generic[GenerativeAdversarialNetworkTaskRoundRobinConfigT, GeneratorT, DiscriminatorT, Batch, GeneratorOutput, DiscriminatorOutput], ABC

Trains a generator and discriminator by alternating between them in a round-robin fashion, as controlled by the generator_steps and discriminator_steps config options.

abstract run_generator(generator: GeneratorT, batch: Batch, state: State) -> GeneratorOutput

Runs the generator model on the given batch.

Parameters:
  • generator – The generator module.

  • batch – The batch to run the model on.

  • state – The current training state.

Returns:

The output of the generator model.
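
As a concrete example, a hedged run_generator() sketch for an image GAN, assuming Batch is a tensor of real images and that the task stores an illustrative latent_dim attribute:

import torch
from torch import Tensor, nn

# Hedged sketch: `self.latent_dim` is an illustrative attribute, not part
# of the base task.
def run_generator(self, generator: nn.Module, batch: Tensor, state) -> Tensor:
    # One latent code per real sample, on the same device as the batch.
    latents = torch.randn(batch.shape[0], self.latent_dim, device=batch.device)
    return generator(latents)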

abstract run_discriminator(discriminator: DiscriminatorT, batch: Batch, gen_output: GeneratorOutput, state: State) -> DiscriminatorOutput

Runs the discriminator model on the given batch.

Parameters:
  • discriminator – The discriminator model.

  • batch – The batch to run the model on.

  • gen_output – The output of the generator model.

  • state – The current training state.

Returns:

The output of the discriminator model.
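
A matching run_discriminator() sketch that scores both the real batch and the generated samples; returning a (real_logits, fake_logits) tuple as the DiscriminatorOutput is an assumption made for illustration:

from torch import Tensor, nn

# Hedged sketch: the (real_logits, fake_logits) tuple is one possible
# choice of DiscriminatorOutput.
def run_discriminator(self, discriminator: nn.Module, batch: Tensor, gen_output: Tensor, state) -> tuple[Tensor, Tensor]:
    real_logits = discriminator(batch)
    # Detaching keeps discriminator updates from flowing into the
    # generator; keep the graph intact instead if the generator loss
    # reuses these logits.
    fake_logits = discriminator(gen_output.detach())
    return real_logits, fake_logits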

abstract compute_discriminator_loss(generator: GeneratorT, discriminator: DiscriminatorT, batch: Batch, state: State, gen_output: GeneratorOutput, dis_output: DiscriminatorOutput) -> dict[str, torch.Tensor]

Computes the discriminator loss for the given batch.

Parameters:
  • generator – The generator model.

  • discriminator – The discriminator model.

  • batch – The batch to run the model on.

  • state – The current training state.

  • gen_output – The output of the generator model.

  • dis_output – The output of the discriminator model.

Returns:

The discriminator loss.
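
A hedged sketch of a standard non-saturating discriminator loss, continuing the (real_logits, fake_logits) assumption above with logits of shape (B,); the dictionary keys are illustrative:

import torch
import torch.nn.functional as F
from torch import Tensor

# Hedged sketch: assumes dis_output is (real_logits, fake_logits), each
# of shape (B,), as in the run_discriminator sketch above.
def compute_discriminator_loss(self, generator, discriminator, batch, state, gen_output, dis_output) -> dict[str, Tensor]:
    real_logits, fake_logits = dis_output
    # reduction="none" keeps the per-sample shape (B) that compute_loss expects.
    real_loss = F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits), reduction="none")
    fake_loss = F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits), reduction="none")
    return {"dis_real": real_loss, "dis_fake": fake_loss}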

compute_generator_loss(generator: GeneratorT, discriminator: DiscriminatorT, batch: Batch, state: State, gen_output: GeneratorOutput, dis_output: DiscriminatorOutput) -> dict[str, torch.Tensor]

Computes the generator loss for the given batch. The parameters match compute_discriminator_loss().
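
A hedged non-saturating counterpart to the discriminator loss sketch above:

import torch
import torch.nn.functional as F
from torch import Tensor

# Hedged sketch, paired with the discriminator loss sketch above.
def compute_generator_loss(self, generator, discriminator, batch, state, gen_output, dis_output) -> dict[str, Tensor]:
    # Re-score undetached generator output so gradients reach the generator.
    fake_logits = discriminator(gen_output)
    loss = F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits), reduction="none")
    return {"gen": loss}
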
is_generator_step(state: State, phase: Literal['train', 'valid', 'test'] | None = None) -> bool

Returns True if the current step should update the generator rather than the discriminator, following the round-robin schedule set by generator_steps and discriminator_steps.

do_logging(generator: GeneratorT, discriminator: DiscriminatorT, batch: Batch, state: State, gen_output: GeneratorOutput, dis_output: DiscriminatorOutput, losses: dict[str, torch.Tensor]) -> None

Override this method to perform any logging.

Doing logging in this hook avoids some annoying context manager issues.
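
A sketch of an override, with the caveat that self.logger.log_image and state.phase are assumptions here, not confirmed API; substitute whatever logging utility your setup provides:

from torch import Tensor

# Hedged sketch; `self.logger.log_image` and `state.phase` are
# assumptions, not confirmed framework API.
def do_logging(self, generator, discriminator, batch, state, gen_output, dis_output, losses: dict[str, Tensor]) -> None:
    if getattr(state, "phase", None) == "valid":
        # Log one generated sample per validation step.
        self.logger.log_image("generated_sample", gen_output[0])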

run_model(model: GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT], batch: Batch, state: State) -> tuple[GeneratorOutput, DiscriminatorOutput]

Runs a single training step and returns the outputs.

Parameters:
  • model – The current nn.Module.

  • batch – The current batch.

  • state – The current trainer state.

Returns:

The outputs from the model.

compute_loss(model: GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT], batch: Batch, state: State, output: tuple[GeneratorOutput, DiscriminatorOutput]) -> dict[str, torch.Tensor]

Computes the loss for a given output.

If the loss is a tensor, it should have shape (B). If the loss is a dictionary of tensors, each tensor should have the same shape (B).

Parameters:
  • model – The current nn.Module.

  • batch – The current batch.

  • state – The current trainer state.

  • output – The model output from run_model.

Returns:

The computed loss, as a tensor or dictionary of tensors.
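
The shape convention in concrete terms, for a batch of 8 (illustrative values only):

import torch

batch_size = 8
# Valid: a single per-sample loss tensor of shape (B) ...
single_loss = torch.rand(batch_size)
# ... or a dictionary in which every tensor has that same shape (B).
loss_dict = {"gen": torch.rand(batch_size), "dis_real": torch.rand(batch_size)}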

on_after_gan_forward_step(generator: GeneratorT, discriminator: DiscriminatorT, batch: Batch, state: State, gen_output: GeneratorOutput, dis_output: DiscriminatorOutput) -> None

GAN-specific hook that is called after a forward step.

This is useful for implementing, for example, the Wasserstein GAN gradient penalty; a sketch follows the parameter list below.

Parameters:
  • generator – The generator model.

  • discriminator – The discriminator model.

  • batch – The batch to run the model on.

  • state – The current training state.

  • gen_output – The output of the generator model.

  • dis_output – The output of the discriminator model.
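
A standard WGAN-GP gradient penalty sketch that could be computed from this hook, assuming 4-D image tensors of matching shape; how the penalty is folded back into the discriminator loss is left to your implementation:

import torch
from torch import Tensor, nn

# Hedged WGAN-GP sketch, assuming (B, C, H, W) tensors for both `real`
# (the batch) and `fake` (the generator output).
def gradient_penalty(discriminator: nn.Module, real: Tensor, fake: Tensor) -> Tensor:
    # Interpolate between real and generated samples.
    alpha = torch.rand(real.shape[0], 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = discriminator(interp)
    # create_graph=True makes the penalty itself differentiable.
    (grads,) = torch.autograd.grad(scores.sum(), interp, create_graph=True)
    # Penalize the gradient norm's deviation from 1.
    return ((grads.flatten(1).norm(2, dim=1) - 1.0) ** 2).mean()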

on_after_forward_step(model: GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT], batch: Batch, output: tuple[GeneratorOutput, DiscriminatorOutput], state: State) -> None