ml.lr_schedulers.gan

GAN learning rate scheduler wrapper.

This wrapper allows downstream users to set different learning rate schedules for the generator and the discriminator of a GAN.

This class is used by the GAN trainer interface and shouldn’t be used elsewhere.
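The idea of one independent schedule per sub-network can be illustrated with a minimal, self-contained sketch. This is not the library's implementation: `ConstantSchedule`, `LinearDecaySchedule`, `GANScheduleSketch`, and the step-based `get_lr_scale(step)` signature are all hypothetical stand-ins chosen for illustration.

```python
from dataclasses import dataclass
from typing import Protocol


class LRSchedule(Protocol):
    """Hypothetical protocol: anything that maps a step count to a learning rate scale."""

    def get_lr_scale(self, step: int) -> float: ...


@dataclass
class ConstantSchedule:
    scale: float = 1.0

    def get_lr_scale(self, step: int) -> float:
        # Same scale at every step.
        return self.scale


@dataclass
class LinearDecaySchedule:
    total_steps: int

    def get_lr_scale(self, step: int) -> float:
        # Decays linearly from 1.0 to 0.0 over total_steps, then stays at 0.0.
        return max(0.0, 1.0 - step / self.total_steps)


@dataclass
class GANScheduleSketch:
    """Hypothetical wrapper holding one independent schedule per sub-network."""

    generator: LRSchedule
    discriminator: LRSchedule

    def get_lr_scales(self, step: int) -> dict[str, float]:
        # Each sub-network's scale is computed by its own schedule.
        return {
            "generator": self.generator.get_lr_scale(step),
            "discriminator": self.discriminator.get_lr_scale(step),
        }


sched = GANScheduleSketch(
    generator=LinearDecaySchedule(total_steps=100),
    discriminator=ConstantSchedule(scale=0.5),
)
print(sched.get_lr_scales(25))  # {'generator': 0.75, 'discriminator': 0.5}
```

Keeping the two schedules as separate objects means either one can be swapped without touching the other, which is the flexibility the wrapper is described as providing.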

class ml.lr_schedulers.gan.GenerativeAdversarialNetworkLRSchedulerConfig(name: str = '???', generator: Any = '???', discriminator: Any = '???')

Bases: BaseLRSchedulerConfig

generator: Any = '???'
discriminator: Any = '???'
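The `'???'` defaults are OmegaConf's marker for a mandatory value that has not been set, so both `generator` and `discriminator` must be filled in before the config is usable. The sketch below shows the nesting with plain dicts and a small helper that reports unset fields. The registered names `"gan"` and `"linear"`, the `warmup_steps` key, and the `find_missing` helper are hypothetical, chosen only for illustration.

```python
MISSING = "???"  # OmegaConf's marker for a mandatory value that has not been set


def find_missing(config: dict, prefix: str = "") -> list[str]:
    """Returns the dotted paths of every value still set to MISSING."""
    missing = []
    for key, value in config.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into nested sub-scheduler configs.
            missing.extend(find_missing(value, prefix=f"{path}."))
        elif value == MISSING:
            missing.append(path)
    return missing


# Hypothetical GAN scheduler config: each field holds a full sub-scheduler config.
gan_config = {
    "name": "gan",
    "generator": {"name": "linear", "warmup_steps": 1000},
    "discriminator": MISSING,
}

print(find_missing(gan_config))  # ['discriminator']
```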
classmethod update(config: DictConfig) → DictConfig

Runs post-construction config update.

Parameters:

config – The config to update

classmethod resolve(config: GenerativeAdversarialNetworkLRSchedulerConfig) → None

Runs post-construction config resolution.

Parameters:

config – The config to resolve
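Because this config nests two complete sub-scheduler configs, one plausible pattern for post-construction resolution is to forward it to each child's own config class. The sketch below assumes that delegation; it is not confirmed by this page. `WarmupConfig`, its `resolved` flag, and `GANConfigSketch` are hypothetical stand-ins.

```python
from dataclasses import dataclass, field


@dataclass
class WarmupConfig:
    """Hypothetical sub-scheduler config."""

    name: str = "warmup"
    warmup_steps: int = 1000
    resolved: bool = False

    @classmethod
    def resolve(cls, config: "WarmupConfig") -> None:
        # Stand-in for whatever per-scheduler resolution a real config performs.
        config.resolved = True


@dataclass
class GANConfigSketch:
    """Hypothetical GAN wrapper config holding one sub-config per network."""

    name: str = "gan"
    generator: WarmupConfig = field(default_factory=WarmupConfig)
    discriminator: WarmupConfig = field(default_factory=WarmupConfig)

    @classmethod
    def resolve(cls, config: "GANConfigSketch") -> None:
        # Assumed behavior: forward resolution to each child config's own class.
        type(config.generator).resolve(config.generator)
        type(config.discriminator).resolve(config.discriminator)


cfg = GANConfigSketch(discriminator=WarmupConfig(warmup_steps=500))
GANConfigSketch.resolve(cfg)
print(cfg.generator.resolved, cfg.discriminator.resolved)  # True True
```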

class ml.lr_schedulers.gan.GenerativeAdversarialNetworkLRScheduler(config: GenerativeAdversarialNetworkLRSchedulerConfig)

Bases: BaseLRScheduler[GenerativeAdversarialNetworkLRSchedulerConfig]

get_lr_scale(state: State) → float

Given a state, returns the current learning rate scale.

Parameters:

state – The current trainer state

Returns:

The computed learning rate scale to use
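If, as the method name suggests, the returned float is a scale applied to a base learning rate, a single schedule's `get_lr_scale` could look like the linear-warmup sketch below. The `State` dataclass and its `num_steps` field, the `warmup_steps` parameter, and the multiplication by `base_lr` are assumptions for illustration, not this class's documented behavior.

```python
from dataclasses import dataclass


@dataclass
class State:
    """Hypothetical stand-in for the trainer state; only the step count is used."""

    num_steps: int


def get_lr_scale(state: State, warmup_steps: int = 100) -> float:
    """Linear warmup: ramps the scale from 1 / warmup_steps up to 1.0, then holds it."""
    return min(1.0, (state.num_steps + 1) / warmup_steps)


base_lr = 2e-4  # hypothetical base learning rate
for step in (0, 49, 99, 500):
    scale = get_lr_scale(State(num_steps=step))
    # Assumed usage: the effective learning rate is the base rate times the scale.
    print(step, scale, base_lr * scale)
```

Returning a unitless scale rather than an absolute rate lets the same schedule be reused with different base learning rates, for example one for the generator's optimizer and another for the discriminator's.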