ml.trainers.mixins.grad_balancer
Implements a modified loss balancer.
The loss balancer balances the gradients of multiple losses. For each loss, the gradients are scaled so that the total gradient norm after all the losses have been backpropagated equals the total_norm parameter.
- class ml.trainers.mixins.grad_balancer.GradBalancer(logger: MultiLogger | None = None, total_norm: float = 1.0, epsilon: float = 0.0001, set_to_none: bool = True, norm_type: float = 2.0, foreach: bool | None = None)[source]
Bases: object
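To illustrate the idea, here is a minimal, framework-free sketch of gradient balancing: each loss's gradient vector is rescaled so that every loss contributes an equal share of total_norm. This is only an illustration of the technique; the actual GradBalancer mixin operates on parameter gradients during backpropagation, and the helper name balance_gradients below is hypothetical.

```python
import math

def balance_gradients(per_loss_grads, total_norm=1.0, epsilon=1e-4, norm_type=2.0):
    # Hypothetical sketch: scale each loss's gradient vector so that every
    # loss contributes an equal share of `total_norm`. The real GradBalancer
    # applies this kind of scaling to parameter gradients during backprop.
    num_losses = len(per_loss_grads)
    target = total_norm / num_losses  # equal share of the norm budget per loss
    balanced = []
    for grads in per_loss_grads:
        # p-norm of this loss's gradients (norm_type=2.0 gives the L2 norm)
        norm = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)
        scale = target / (norm + epsilon)  # epsilon guards against tiny norms
        balanced.append([g * scale for g in grads])
    return balanced

# Two losses whose raw gradient magnitudes differ by 50x:
grads_a = [3.0, 4.0]    # L2 norm 5.0
grads_b = [0.06, 0.08]  # L2 norm 0.1
balanced = balance_gradients([grads_a, grads_b], total_norm=1.0)
```

After balancing, each loss's gradient vector has norm close to 0.5, so the two losses contribute equally and their norms sum to roughly total_norm.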