ml.trainers.gan
Defines a trainer for training GANs.
This trainer mirrors the supervised learning trainer, but maintains separate optimizers for the generator and discriminator and supports round-robin training.
- class ml.trainers.gan.GenerativeAdversarialNetworkTrainerConfig(name: str = '???', exp_name: str = '${ml.exp_name:null}', exp_dir: str = '???', log_dir_name: str = 'logs', use_double_weight_precision: bool = False, checkpoint: ml.trainers.base.CheckpointConfig = <factory>, parallel: ml.trainers.mixins.data_parallel.ParallelConfig = <factory>, compiler: ml.trainers.mixins.compile.TorchCompileConfig = <factory>, cpu_stats_ping_interval: int = 1, cpu_stats_only_log_once: bool = False, gpu_stats_ping_interval: int = 10, gpu_stats_only_log_once: bool = False, mixed_precision: ml.trainers.mixins.mixed_precision.MixedPrecisionConfig = <factory>, clip_grad_norm: float = 10.0, clip_grad_norm_type: Any = 2, balance_grad_norms: bool = False, profiler: ml.trainers.mixins.profiler.Profiler = <factory>, set_to_none: bool = True, deterministic: bool = False, use_tf32: bool = True, detect_anomaly: bool = False, validation: ml.trainers.sl.ValidationConfig = <factory>, batches_per_step: int = 1, batches_per_step_schedule: list[ml.trainers.sl.BatchScheduleConfig] | None = None, batch_chunks_per_step_schedule: list[ml.trainers.sl.BatchScheduleConfig] | None = None, batch_dim: int = 0, discriminator_key: str = 'dis', generator_key: str = 'gen', round_robin: bool = True)[source]
Bases: SupervisedLearningTrainerConfig
- discriminator_key: str = 'dis'
- generator_key: str = 'gen'
- round_robin: bool = True
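The config fields above name the two sub-networks (`'dis'` and `'gen'`) and toggle round-robin training. The reference does not spell out the alternation order, so the sketch below is an assumption for illustration: with `round_robin=True` the discriminator and generator take turns on successive steps, and with `round_robin=False` both are updated every step.

```python
def phase_for_step(step: int, round_robin: bool) -> list[str]:
    """Hypothetical sketch of which sub-networks update at a given step.

    Assumes round-robin training alternates between the discriminator
    ("dis", even steps) and the generator ("gen", odd steps); without
    round-robin, both update every step. This is an illustration, not
    the library's documented schedule.
    """
    if not round_robin:
        return ["dis", "gen"]
    return ["dis"] if step % 2 == 0 else ["gen"]
```

Under this assumption, steps 0, 1, 2 with round-robin enabled would train `dis`, then `gen`, then `dis` again.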
- class ml.trainers.gan.GenerativeAdversarialNetworkTrainer(config: SupervisedLearningTrainerConfigT)[source]
Bases: SupervisedLearningTrainer[GenerativeAdversarialNetworkTrainerConfigT, GenerativeAdversarialNetworkModelT, GenerativeAdversarialNetworkTaskT], Generic[GenerativeAdversarialNetworkTrainerConfigT, GenerativeAdversarialNetworkModelT, GenerativeAdversarialNetworkTaskT]
- gan_train_step(*, task_model: Module, params: tuple[list[torch.nn.parameter.Parameter], list[torch.nn.parameter.Parameter]], batches: Iterator[Batch], state: State, task: GenerativeAdversarialNetworkTaskT, model: GenerativeAdversarialNetworkModelT, optim: Optimizer | Collection[Optimizer], lr_sched: SchedulerAdapter | Collection[SchedulerAdapter]) dict[str, torch.Tensor] [source]
- get_params(task_model: Module) tuple[list[torch.nn.parameter.Parameter], list[torch.nn.parameter.Parameter]] [source]
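`get_params` returns the two parameter lists that `gan_train_step` consumes. One plausible way to split them (an assumption based on the `generator_key`/`discriminator_key` config fields, not the library's actual code) is by the leading component of each parameter's dotted name:

```python
def split_params(named_params, gen_key="gen", dis_key="dis"):
    """Hypothetical sketch: partition (name, parameter) pairs into
    (generator, discriminator) lists by the first dotted-name component,
    matching the trainer's generator_key/discriminator_key defaults."""
    gen_params, dis_params = [], []
    for name, param in named_params:
        head = name.split(".", 1)[0]
        if head == gen_key:
            gen_params.append(param)
        elif head == dis_key:
            dis_params.append(param)
    return gen_params, dis_params
```

In practice `named_params` would come from `task_model.named_parameters()`; plain tuples work for a quick check.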
- train(model: GenerativeAdversarialNetworkModelT, task: GenerativeAdversarialNetworkTaskT, optimizer: BaseOptimizer, lr_scheduler: BaseLRScheduler) None [source]
Runs the training loop.
- Parameters:
model – The model to train.
task – The task to train on.
optimizer – The optimizer to use.
lr_scheduler – The learning rate scheduler to use.