ml.trainers.gan

Defines a trainer for training GANs.

This trainer is similar to the supervised learning trainer, but it maintains separate optimizers for the generator and the discriminator, and it supports round-robin training.
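The core idea can be sketched in plain PyTorch (a minimal illustration only, not this trainer's actual implementation; the tiny networks and data here are hypothetical):

```python
import torch
from torch import nn

# Hypothetical tiny generator and discriminator for 1-D data.
gen = nn.Linear(4, 1)
dis = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1))

# Separate optimizers, one per network.
gen_opt = torch.optim.Adam(gen.parameters(), lr=1e-3)
dis_opt = torch.optim.Adam(dis.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

for step in range(4):
    real = torch.randn(16, 1)
    noise = torch.randn(16, 4)
    if step % 2 == 0:
        # Round-robin: even steps update only the discriminator.
        fake = gen(noise).detach()  # don't backprop into the generator
        loss = loss_fn(dis(real), torch.ones(16, 1)) + loss_fn(dis(fake), torch.zeros(16, 1))
        dis_opt.zero_grad()
        loss.backward()
        dis_opt.step()
    else:
        # Odd steps update only the generator, trying to fool the discriminator.
        fake = gen(noise)
        loss = loss_fn(dis(fake), torch.ones(16, 1))
        gen_opt.zero_grad()
        loss.backward()
        gen_opt.step()
```

Detaching the fake samples during the discriminator step keeps the generator's gradients untouched, which is what having two optimizers buys over a single combined one.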

class ml.trainers.gan.GenerativeAdversarialNetworkTrainerConfig(name: str = '???', exp_name: str = '${ml.exp_name:null}', exp_dir: str = '???', log_dir_name: str = 'logs', use_double_weight_precision: bool = False, checkpoint: ml.trainers.base.CheckpointConfig = <factory>, parallel: ml.trainers.mixins.data_parallel.ParallelConfig = <factory>, compiler: ml.trainers.mixins.compile.TorchCompileConfig = <factory>, cpu_stats_ping_interval: int = 1, cpu_stats_only_log_once: bool = False, gpu_stats_ping_interval: int = 10, gpu_stats_only_log_once: bool = False, mixed_precision: ml.trainers.mixins.mixed_precision.MixedPrecisionConfig = <factory>, clip_grad_norm: float = 10.0, clip_grad_norm_type: Any = 2, balance_grad_norms: bool = False, profiler: ml.trainers.mixins.profiler.Profiler = <factory>, set_to_none: bool = True, deterministic: bool = False, use_tf32: bool = True, detect_anomaly: bool = False, validation: ml.trainers.sl.ValidationConfig = <factory>, batches_per_step: int = 1, batches_per_step_schedule: list[ml.trainers.sl.BatchScheduleConfig] | None = None, batch_chunks_per_step_schedule: list[ml.trainers.sl.BatchScheduleConfig] | None = None, batch_dim: int = 0, discriminator_key: str = 'dis', generator_key: str = 'gen', round_robin: bool = True)[source]

Bases: SupervisedLearningTrainerConfig

discriminator_key: str = 'dis'
generator_key: str = 'gen'
round_robin: bool = True
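When `round_robin` is enabled, the trainer alternates which network is updated on each step. The selection logic can be sketched as follows (a hypothetical helper written for illustration, not part of the library's API; the returned strings mirror the `discriminator_key` and `generator_key` defaults):

```python
def phase(step: int, round_robin: bool = True) -> str:
    """Pick which network to update at this training step.

    With round-robin training the discriminator ("dis") and the
    generator ("gen") take turns; otherwise both are updated every
    step, signalled here as "both".
    """
    if not round_robin:
        return "both"
    return "dis" if step % 2 == 0 else "gen"
```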
class ml.trainers.gan.GenerativeAdversarialNetworkTrainer(config: GenerativeAdversarialNetworkTrainerConfigT)[source]

Bases: SupervisedLearningTrainer[GenerativeAdversarialNetworkTrainerConfigT, GenerativeAdversarialNetworkModelT, GenerativeAdversarialNetworkTaskT], Generic[GenerativeAdversarialNetworkTrainerConfigT, GenerativeAdversarialNetworkModelT, GenerativeAdversarialNetworkTaskT]

gan_train_step(*, task_model: Module, params: tuple[list[torch.nn.parameter.Parameter], list[torch.nn.parameter.Parameter]], batches: Iterator[Batch], state: State, task: GenerativeAdversarialNetworkTaskT, model: GenerativeAdversarialNetworkModelT, optim: Optimizer | Collection[Optimizer], lr_sched: SchedulerAdapter | Collection[SchedulerAdapter]) → dict[str, torch.Tensor][source]
get_params(task_model: Module) → tuple[list[torch.nn.parameter.Parameter], list[torch.nn.parameter.Parameter]][source]
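`get_params` returns the generator and discriminator parameter lists as a tuple. One plausible partitioning scheme, shown purely as an assumption (the actual split depends on how the task model names its submodules), is to group parameters by name prefix:

```python
from torch import nn

def split_params(
    task_model: nn.Module,
    gen_key: str = "gen",
    dis_key: str = "dis",
) -> tuple[list[nn.Parameter], list[nn.Parameter]]:
    """Partition a model's parameters into (generator, discriminator)
    lists by the leading component of each parameter name.

    This is a hypothetical scheme for illustration; `dis_key` is shown
    for symmetry, with everything not under `gen_key` treated as
    discriminator parameters.
    """
    gen_params: list[nn.Parameter] = []
    dis_params: list[nn.Parameter] = []
    for name, param in task_model.named_parameters():
        if name.startswith(gen_key):
            gen_params.append(param)
        else:
            dis_params.append(param)
    return gen_params, dis_params

# Usage with a model whose submodules match the default config keys.
model = nn.ModuleDict({"gen": nn.Linear(4, 1), "dis": nn.Linear(1, 1)})
gen_params, dis_params = split_params(model)
```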
train(model: GenerativeAdversarialNetworkModelT, task: GenerativeAdversarialNetworkTaskT, optimizer: BaseOptimizer, lr_scheduler: BaseLRScheduler) → None[source]

Runs the training loop.

Parameters:
  • model – The model to train.

  • task – The task to train on.

  • optimizer – The optimizer to use.

  • lr_scheduler – The learning rate scheduler to use.