ml.models.init

Defines a general-purpose API for weight initialization.

from torch import nn

from ml.models.init import init_, cast_init_type

linear = nn.Linear(32, 32)
init_(linear.weight, linear.bias, "orthogonal")

# This lets you parametrize the initialization type as a string.
init_(linear.weight, linear.bias, cast_init_type(my_init_type))

Choices for the initialization type are:

  • "orthogonal": Initializes weights to an orthogonal matrix

  • "normal": Initializes weights with a normal distribution

  • "biased_normal": Initializes both weights and biases with a normal distribution

  • "uniform": Initializes weights with a uniform distribution

  • "kaiming_uniform" or "kaiming_normal": Initializes weights with a Kaiming uniform or normal distribution, respectively

  • "xavier_uniform" or "xavier_normal": Initializes weights with a Xavier uniform or normal distribution, respectively

  • "trunc_normal": Initializes weights with a truncated normal distribution, clipped to the trunc_clip range

  • "dirac": Initializes weights with the Dirac delta function, preserving the identity of the input channels

  • "constant": Initializes weights to a constant value given by scale

  • "zeros": Initializes weights to all zeros

  • "ones": Initializes weights to all ones
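These choices map onto the standard initializers in torch.nn.init. A minimal sketch of how such a string dispatch might look (an illustration only, not the library's actual implementation, and covering only a few of the modes):

```python
import torch
from torch import Tensor, nn


def sketch_init_(weight: Tensor, init: str, mean: float = 0.0, std: float = 0.01) -> Tensor:
    # Hypothetical dispatcher; the real init_ supports more modes and a bias term.
    if init == "orthogonal":
        return nn.init.orthogonal_(weight)
    if init == "normal":
        return nn.init.normal_(weight, mean=mean, std=std)
    if init == "kaiming_uniform":
        return nn.init.kaiming_uniform_(weight)
    if init == "zeros":
        return nn.init.zeros_(weight)
    if init == "ones":
        return nn.init.ones_(weight)
    raise NotImplementedError(f"Unsupported initialization: {init}")


w = torch.empty(32, 32)
sketch_init_(w, "orthogonal")
# An orthogonal matrix satisfies W @ W.T == I (up to floating-point error).
assert torch.allclose(w @ w.T, torch.eye(32), atol=1e-5)
```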

ml.models.init.cast_init_type(s: str) → Literal['orthogonal', 'normal', 'biased_normal', 'uniform', 'kaiming_uniform', 'kaiming_normal', 'xavier_uniform', 'xavier_normal', 'trunc_normal', 'dirac', 'constant', 'zeros', 'ones']
ml.models.init.init_(weight: Tensor, bias: Tensor | None, init: Literal['orthogonal', 'normal', 'biased_normal', 'uniform', 'kaiming_uniform', 'kaiming_normal', 'xavier_uniform', 'xavier_normal', 'trunc_normal', 'dirac', 'constant', 'zeros', 'ones'], *, mean: float = 0.0, std: float = 0.01, scale: float = 0.02, groups: int = 1, trunc_clip: tuple[float, float] = (-2.0, 2.0)) → tuple[torch.Tensor, torch.Tensor | None]

Initializes the weight and bias in-place, using an initialization key.

The weight and bias are from a convolution or linear layer.

Parameters:
  • weight – The weight tensor

  • bias – The bias tensor

  • init – The initialization type to use

  • mean – The mean for normal initialization

  • std – The standard deviation for normal initialization

  • scale – The scale amount for uniform or constant initialization

  • groups – The number of groups, for initialization modes that require it

  • trunc_clip – The min and max values for trunc_normal initialization

Returns:

The initialized weight and bias (which can be discarded, since the initialization happens in-place).
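Returning the tensors follows the torch.nn.init convention: in-place initializers return the very tensor they modified, so the return value can be chained or ignored. A quick demonstration with torch.nn.init itself:

```python
import torch
from torch import nn

weight = torch.empty(8, 8)
out = nn.init.zeros_(weight)

# The returned tensor is the same object that was passed in.
assert out is weight
assert torch.all(weight == 0)
```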

Raises:

NotImplementedError – If the initialization mode isn’t implemented
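A validating cast like cast_init_type can be written against the Literal type itself, so the accepted strings never drift from the annotation. A hypothetical sketch (the function name, error type, and implementation here are assumptions; the library's actual implementation may differ):

```python
from typing import Literal, get_args

InitializationType = Literal[
    "orthogonal", "normal", "biased_normal", "uniform",
    "kaiming_uniform", "kaiming_normal", "xavier_uniform", "xavier_normal",
    "trunc_normal", "dirac", "constant", "zeros", "ones",
]


def cast_init_type_sketch(s: str) -> InitializationType:
    # Validate at runtime against the strings listed in the Literal annotation.
    if s not in get_args(InitializationType):
        raise ValueError(f"Invalid initialization type: {s}")
    return s  # type: ignore[return-value]
```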