ml.trainers.mixins.mixed_precision

Defines a mixin for doing FP16 scaling.

FP16 scaling is a technique for training with FP16 precision while keeping an FP32 copy of the model weights. Because FP16 has a narrow dynamic range, small gradient values can underflow to zero; to prevent this, the loss is multiplied by a large scale factor (e.g. 2^16) before the backward pass, and the gradients are multiplied by the inverse of that factor before the optimizer step. The scaler backs off (reduces the scale) whenever the scaled gradients overflow to inf or NaN, so if the scale factor starts to decrease steadily, it means the loss is overflowing and training is diverging.
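
This is the same mechanism PyTorch exposes through torch.cuda.amp.GradScaler. As a point of reference, a minimal scaled training step looks like the sketch below; the model, optimizer, and data are placeholders, not part of this module:

    import torch

    model = torch.nn.Linear(16, 1).cuda()
    optim = torch.optim.SGD(model.parameters(), lr=1e-2)
    scaler = torch.cuda.amp.GradScaler(init_scale=2**16)

    x = torch.randn(8, 16, device="cuda")
    y = torch.randn(8, 1, device="cuda")

    optim.zero_grad()

    # The forward pass runs in FP16 under autocast; the weights stay FP32.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()  # Backprop the loss multiplied by the scale factor.
    scaler.step(optim)             # Unscales gradients; skips the step on overflow.
    scaler.update()                # Grows the scale, or backs it off after an overflow.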

class ml.trainers.mixins.mixed_precision.MixedPrecisionConfig(enabled: bool = True, init_scale: float = 65536.0, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, min_grad_scale: float = 0.0001, foreach: bool | None = None)[source]

Bases: object

enabled: bool = True
init_scale: float = 65536.0
growth_factor: float = 2.0
backoff_factor: float = 0.5
growth_interval: int = 2000
min_grad_scale: float = 0.0001
foreach: bool | None = None
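
The enabled, init_scale, growth_factor, backoff_factor, and growth_interval fields line up one-to-one with the constructor arguments of torch.cuda.amp.GradScaler. The sketch below shows how a trainer might build its scaler from this config; the use of min_grad_scale as a divergence threshold is an assumption about its role, not taken from the source:

    import torch

    from ml.trainers.mixins.mixed_precision import MixedPrecisionConfig

    cfg = MixedPrecisionConfig()

    # These five fields mirror GradScaler's constructor arguments.
    scaler = torch.cuda.amp.GradScaler(
        enabled=cfg.enabled,
        init_scale=cfg.init_scale,
        growth_factor=cfg.growth_factor,
        backoff_factor=cfg.backoff_factor,
        growth_interval=cfg.growth_interval,
    )

    # Assumed usage of min_grad_scale: if repeated overflows have driven the
    # scale below this floor, treat training as diverged.
    if scaler.get_scale() < cfg.min_grad_scale:
        raise RuntimeError("Gradient scale collapsed; the loss is likely diverging.")
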
class ml.trainers.mixins.mixed_precision.MixedPrecisionTrainerConfig(name: str = '???', exp_name: str = '${ml.exp_name:null}', exp_dir: str = '???', log_dir_name: str = 'logs', use_double_weight_precision: bool = False, checkpoint: ml.trainers.base.CheckpointConfig = <factory>, mixed_precision: ml.trainers.mixins.mixed_precision.MixedPrecisionConfig = <factory>, clip_grad_norm: float = 10.0, clip_grad_norm_type: Any = 2, balance_grad_norms: bool = False)[source]

Bases: BaseTrainerConfig

mixed_precision: MixedPrecisionConfig
clip_grad_norm: float = 10.0
clip_grad_norm_type: Any = 2
balance_grad_norms: bool = False
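
When clip_grad_norm is combined with loss scaling, the gradients must be unscaled before the norm is measured; otherwise the threshold would be compared against scaled values. Continuing the earlier sketch (same model, optim, loss, and scaler), the standard PyTorch ordering is shown below; step_optimizer presumably performs the equivalent sequence internally, and the foreach field may be forwarded to clip_grad_norm_'s foreach argument, though that is an assumption:

    scaler.scale(loss).backward()
    scaler.unscale_(optim)  # Divide the gradients by the scale factor, in place.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0, norm_type=2)
    scaler.step(optim)      # Uses the already-unscaled gradients as-is.
    scaler.update()
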
class ml.trainers.mixins.mixed_precision.MixedPrecisionTrainerMixin(config: MixedPrecisionConfigT)[source]

Bases: BaseTrainer[MixedPrecisionConfigT, ModelT, TaskT]

Defines a trainer mixin for doing FP16 scaling.

scale_mixed_precision(tensor: Tensor) → Tensor[source]
backward_grads(model: Module, loss: Tensor, loss_names: list[str], retain_graph: bool | None = None, inputs: Sequence[Tensor] | None = None) → None[source]
step_optimizer(model: Module, optim: Optimizer, num_steps: int = 1) → None[source]
log_mp_scale() → None[source]
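
Put together, a training step routed through the mixin plausibly follows the call order sketched below. This is inferred from the method names and signatures; the exact division of labor between scale_mixed_precision and backward_grads is an assumption, and task, batch, and optim are placeholders:

    # Hypothetical training step inside a trainer built on this mixin.
    loss = task.compute_loss(model, batch)
    scaled = self.scale_mixed_precision(loss)          # Multiply by the current scale.
    self.backward_grads(model, scaled, loss_names=["loss"])
    self.step_optimizer(model, optim)                  # Unscale, clip, step, update.
    self.log_mp_scale()                                # Log the current scale factor.
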
load_state_dict(state_dict: dict) → None[source]

Function for loading state dict keys for the different components.

Parameters:
  • state_dict – The state dictionary being loaded (overriders should mutate it in place)

update_state_dict(state_dict: dict) → None[source]

Function for updating the state dictionary that will be saved as the checkpoint.

Parameters:
  • state_dict – The state dictionary being saved (overriders should mutate it in place)
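
A subclass that carries extra mixed-precision state can hook into both methods. The sketch below persists a GradScaler through checkpoints; the attribute name self.grad_scaler and the "grad_scaler" key are hypothetical, chosen only for illustration:

    class MyTrainer(MixedPrecisionTrainerMixin):
        def update_state_dict(self, state_dict: dict) -> None:
            super().update_state_dict(state_dict)
            # Mutate the dict in place, as the hook requires.
            state_dict["grad_scaler"] = self.grad_scaler.state_dict()

        def load_state_dict(self, state_dict: dict) -> None:
            super().load_state_dict(state_dict)
            if "grad_scaler" in state_dict:
                self.grad_scaler.load_state_dict(state_dict["grad_scaler"])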