ml.tasks.gan.base
Defines the base GAN task type.
Subclasses are expected to implement the following methods:

```python
from torch import Tensor
from torch.utils.data import Dataset

# Config, Generator, Discriminator, Batch, GeneratorOutput and DiscriminatorOutput
# are your own types; `ml` is this library's top-level namespace.


class MyGanTask(
    ml.GenerativeAdversarialNetworkTask[
        Config,
        Generator,
        Discriminator,
        Batch,
        GeneratorOutput,
        DiscriminatorOutput,
    ],
):
    def run_generator(self, generator: Generator, batch: Batch, state: ml.State) -> GeneratorOutput:
        ...

    def run_discriminator(
        self,
        discriminator: Discriminator,
        batch: Batch,
        gen_output: GeneratorOutput,
        state: ml.State,
    ) -> DiscriminatorOutput:
        ...

    def compute_discriminator_loss(
        self,
        generator: Generator,
        discriminator: Discriminator,
        batch: Batch,
        state: ml.State,
        gen_output: GeneratorOutput,
        dis_output: DiscriminatorOutput,
    ) -> dict[str, Tensor]:
        ...

    def get_dataset(self, phase: ml.Phase) -> Dataset:
        ...
```
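To make the data flow between the three abstract methods concrete, here is a framework-free sketch on one-dimensional toy data. It does not import `ml` or PyTorch: `ToyGenerator`, `ToyDiscriminator`, `ToyGanTask`, the `(real_sample, noise)` batch layout, and the `softplus` helper are hypothetical stand-ins, and losses are returned as per-sample lists rather than tensors. The discriminator loss shown is ordinary binary cross-entropy on logits, one common choice rather than anything this library prescribes.

```python
import math
from dataclasses import dataclass


@dataclass
class ToyGenerator:
    """Hypothetical generator: maps a noise value z to scale * z + shift."""

    scale: float = 1.0
    shift: float = 0.0

    def __call__(self, z: float) -> float:
        return self.scale * z + self.shift


@dataclass
class ToyDiscriminator:
    """Hypothetical discriminator: returns a logit w * x + b (higher means "more real")."""

    w: float = 1.0
    b: float = 0.0

    def __call__(self, x: float) -> float:
        return self.w * x + self.b


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


class ToyGanTask:
    """Mirrors the template's method names; stands in for the real base class."""

    def run_generator(self, generator, batch, state):
        # Each batch item is a (real_sample, noise) pair; the generator only sees the noise.
        return [generator(z) for _, z in batch]

    def run_discriminator(self, discriminator, batch, gen_output, state):
        # Score real and generated samples separately, one logit per item.
        real_logits = [discriminator(x) for x, _ in batch]
        fake_logits = [discriminator(x) for x in gen_output]
        return real_logits, fake_logits

    def compute_discriminator_loss(self, generator, discriminator, batch, state, gen_output, dis_output):
        real_logits, fake_logits = dis_output
        # Per-sample binary cross-entropy on logits:
        # -log(sigmoid(r)) == softplus(-r) and -log(1 - sigmoid(f)) == softplus(f).
        return {
            "dis_real": [softplus(-r) for r in real_logits],
            "dis_fake": [softplus(f) for f in fake_logits],
        }


# Usage: two (real_sample, noise) pairs and no training state.
task = ToyGanTask()
gen, dis = ToyGenerator(scale=0.5), ToyDiscriminator(w=2.0)
batch = [(1.0, 0.2), (0.8, -0.4)]
fake = task.run_generator(gen, batch, state=None)
logits = task.run_discriminator(dis, batch, fake, state=None)
losses = task.compute_discriminator_loss(gen, dis, batch, None, fake, logits)
```

Keeping one loss value per batch item (rather than averaging early) matches the shape convention that `compute_loss` documents further down.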
- class ml.tasks.gan.base.GenerativeAdversarialNetworkTaskConfig(max_epochs: int | None = None, max_steps: int | None = None, max_samples: int | None = None, max_seconds: float | None = None, train_dl: ml.tasks.base.DataLoaderConfig = <factory>, valid_dl: ml.tasks.base.DataLoaderConfig = <factory>, test_dl: ml.tasks.base.DataLoaderConfig = <factory>, name: str = '???', errors: ml.tasks.datasets.error_handling.ErrorHandlingConfig = <factory>)
Bases: SupervisedLearningTaskConfig
- class ml.tasks.gan.base.GenerativeAdversarialNetworkTask(config: BaseTaskConfigT)
Bases: SupervisedLearningTask[GenerativeAdversarialNetworkTaskConfigT, GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT], Batch, tuple[GeneratorOutput, DiscriminatorOutput], dict[str, Tensor]], Generic[GenerativeAdversarialNetworkTaskConfigT, GeneratorT, DiscriminatorT, Batch, GeneratorOutput, DiscriminatorOutput], ABC
Initializes the task from its config, along with the internal nn.Module state.
- abstract run_generator(generator: GeneratorT, batch: Batch, state: State) -> GeneratorOutput
Runs the generator model on the given batch.
- Parameters:
generator – The generator module.
batch – The batch to run the model on.
state – The current training state.
- Returns:
The output of the generator model
- abstract run_discriminator(discriminator: DiscriminatorT, batch: Batch, gen_output: GeneratorOutput, state: State) -> DiscriminatorOutput
Runs the discriminator model on the given batch.
- Parameters:
discriminator – The discriminator model.
batch – The batch to run the model on.
gen_output – The output of the generator model.
state – The current training state.
- Returns:
The output of the discriminator model
- abstract compute_discriminator_loss(generator: GeneratorT, discriminator: DiscriminatorT, batch: Batch, state: State, gen_output: GeneratorOutput, dis_output: DiscriminatorOutput) -> dict[str, torch.Tensor]
Computes the discriminator loss for the given batch.
- Parameters:
generator – The generator model.
discriminator – The discriminator model.
batch – The batch to run the model on.
state – The current training state.
gen_output – The output of the generator model.
dis_output – The output of the discriminator model.
- Returns:
The discriminator loss.
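For reference, one common choice for this method is the original GAN binary cross-entropy objective; the library does not mandate it. Writing $D$ for the discriminator's logit, $\sigma$ for the sigmoid, $x_i$ for a real sample and $G(z_i)$ for a generated one, the per-sample discriminator loss over a batch of size $B$ is shown first, and the frequently paired non-saturating generator loss (relevant if you override `compute_generator_loss`) second:

$$
\mathcal{L}_D^{(i)} = -\log \sigma\big(D(x_i)\big) - \log\Big(1 - \sigma\big(D(G(z_i))\big)\Big), \qquad i = 1, \dots, B
$$

$$
\mathcal{L}_G^{(i)} = -\log \sigma\big(D(G(z_i))\big), \qquad i = 1, \dots, B
$$

Keeping the losses per sample, rather than averaging them, yields one value per batch item.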
- compute_generator_loss(generator: GeneratorT, discriminator: DiscriminatorT, batch: Batch, state: State, gen_output: GeneratorOutput, dis_output: DiscriminatorOutput) -> dict[str, torch.Tensor]
- do_logging(generator: GeneratorT, discriminator: DiscriminatorT, batch: Batch, state: State, gen_output: GeneratorOutput, dis_output: DiscriminatorOutput, losses: dict[str, torch.Tensor]) -> None
Override this method to perform any logging. Doing the logging here avoids some awkward context-manager issues.
- run_model(model: GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT], batch: Batch, state: State) -> tuple[GeneratorOutput, DiscriminatorOutput]
Runs a single training step and returns the outputs.
- Parameters:
model – The current nn.Module
batch – The current batch
state – The current trainer state
- Returns:
The outputs from the model
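The signatures suggest that `run_model` chains the two abstract methods: the generator's output is fed into `run_discriminator`, and both outputs come back as a tuple. The sketch below illustrates that call order under stated assumptions; it is not the library's source. In particular, it assumes the combined model exposes `generator` and `discriminator` attributes, and a `SimpleNamespace` with those hypothetical attributes stands in for `GenerativeAdversarialNetworkModel`.

```python
from types import SimpleNamespace


class SketchTask:
    def __init__(self):
        self.calls = []  # records the order in which the hooks run

    def run_generator(self, generator, batch, state):
        self.calls.append("run_generator")
        return [generator(z) for z in batch]

    def run_discriminator(self, discriminator, batch, gen_output, state):
        self.calls.append("run_discriminator")
        return [discriminator(x) for x in gen_output]

    def run_model(self, model, batch, state):
        # Assumed composition: generate first, then hand the samples to the discriminator.
        gen_output = self.run_generator(model.generator, batch, state)
        dis_output = self.run_discriminator(model.discriminator, batch, gen_output, state)
        return gen_output, dis_output


# Hypothetical stand-in for GenerativeAdversarialNetworkModel.
model = SimpleNamespace(generator=lambda z: 2 * z, discriminator=lambda x: x > 1)
task = SketchTask()
gen_out, dis_out = task.run_model(model, batch=[0, 1], state=None)
```

Because the discriminator receives `gen_output` explicitly, a subclass never has to call the generator itself; it only describes how each network consumes its inputs.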
- compute_loss(model: GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT], batch: Batch, state: State, output: tuple[GeneratorOutput, DiscriminatorOutput]) -> dict[str, torch.Tensor]
Computes the loss for a given output.
If the loss is a tensor, it should have shape (B), where B is the batch size. If the loss is a dictionary of tensors, each tensor should have that same shape (B).
- Parameters:
model – The current nn.Module
batch – The current batch
state – The current trainer state
output – The model output from run_model
- Returns:
The computed loss, as a tensor or dictionary of tensors
- separate_losses(losses: dict[str, torch.Tensor]) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
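`compute_loss` returns a single dictionary, while `separate_losses` maps such a dictionary back into two, presumably one per network. One way such a round trip could work is sketched below, assuming hypothetical `gen/` and `dis/` key prefixes and a (generator, discriminator) ordering of the returned tuple; the real implementation may tag and order the entries differently. Each value is a per-sample list of length B, matching the shape requirement stated for `compute_loss`.

```python
def merge_losses(gen_losses, dis_losses):
    # Tag each key with its origin so the merged dict can be split later (hypothetical scheme).
    merged = {f"gen/{k}": v for k, v in gen_losses.items()}
    merged.update({f"dis/{k}": v for k, v in dis_losses.items()})
    return merged


def separate_losses(losses):
    # Inverse of merge_losses: route each entry back by its prefix.
    gen, dis = {}, {}
    for key, value in losses.items():
        prefix, _, name = key.partition("/")
        if prefix == "gen":
            gen[name] = value
        elif prefix == "dis":
            dis[name] = value
        else:
            raise KeyError(f"Loss {key!r} has no gen/ or dis/ prefix")
    return gen, dis


B = 3
gen_losses = {"adv": [0.7, 0.6, 0.9]}  # one value per sample, shape (B)
dis_losses = {"real": [0.2, 0.3, 0.1], "fake": [0.5, 0.4, 0.6]}
merged = merge_losses(gen_losses, dis_losses)
assert all(len(v) == B for v in merged.values())
gen_back, dis_back = separate_losses(merged)
```

Raising on an unprefixed key, rather than silently dropping it, makes a mislabeled loss fail loudly instead of never being optimized.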
- on_after_gan_forward_step(generator: GeneratorT, discriminator: DiscriminatorT, batch: Batch, state: State, gen_output: GeneratorOutput, dis_output: DiscriminatorOutput) -> None
GAN-specific hook that is called after a forward step.
This is useful for implementing the Wasserstein GAN gradient penalty.
- Parameters:
generator – The generator model.
discriminator – The discriminator model.
batch – The batch to run the model on.
state – The current training state.
gen_output – The output of the generator model.
dis_output – The output of the discriminator model.
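For reference, the WGAN gradient penalty is $\lambda\,(\lVert \nabla_x D(\hat{x}) \rVert_2 - 1)^2$, evaluated at random interpolates $\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}$ between real samples $x$ and generated samples $\tilde{x}$, and averaged over the batch. The framework-free sketch below computes it with central finite differences so that it runs without PyTorch; in a real task you would obtain the gradient with autograd instead (for example `torch.autograd.grad(..., create_graph=True)`). The `critic`, `numerical_grad` and `gradient_penalty` names and the λ = 10 default are illustrative assumptions, and the sketch says nothing about how this library feeds the penalty into the optimizer.

```python
import math
import random


def critic(x):
    # Hypothetical nonlinear critic on 2-D inputs.
    return math.tanh(x[0]) + 0.5 * x[1] ** 2


def numerical_grad(f, x, h=1e-5):
    # Central finite differences, one coordinate at a time.
    grad = []
    for i in range(len(x)):
        up = list(x)
        up[i] += h
        down = list(x)
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


def gradient_penalty(f, real, fake, lam=10.0, rng=random):
    # Average lam * (||grad f(x_hat)||_2 - 1)^2 over interpolates of paired real/fake samples.
    total = 0.0
    for x_real, x_fake in zip(real, fake):
        eps = rng.random()
        x_hat = [eps * r + (1 - eps) * g for r, g in zip(x_real, x_fake)]
        norm = math.sqrt(sum(g * g for g in numerical_grad(f, x_hat)))
        total += lam * (norm - 1.0) ** 2
    return total / len(real)


real = [[0.5, 1.0], [1.5, -0.5]]
fake = [[-0.3, 0.2], [0.1, 0.0]]
penalty = gradient_penalty(critic, real, fake, rng=random.Random(0))
```

A critic whose gradient has unit norm everywhere, such as `0.6 * x[0] + 0.8 * x[1]`, incurs essentially zero penalty, while `2.0 * x[0]` (gradient norm 2) incurs λ·(2 − 1)² = 10.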
- on_after_forward_step(model: GenerativeAdversarialNetworkModel[GeneratorT, DiscriminatorT], batch: Batch, output: tuple[GeneratorOutput, DiscriminatorOutput], state: State) -> None