ml.models.init
Defines a general-purpose API for weight initialization.
from torch import nn

from ml.models.init import init_, cast_init_type

linear = nn.Linear(32, 32)
init_(linear.weight, linear.bias, "orthogonal")

# Casting lets you parametrize the initialization type as a string.
init_(linear.weight, linear.bias, cast_init_type(my_init_type))
Choices for the initialization type are:
"orthogonal"
: Orthogonal initialization, meaning that the weights are initialized to an orthogonal matrix
"normal"
: Initializes weights with a normal distribution
"biased_normal"
: Initializes both weights and biases with a normal distribution
"uniform"
: Initializes weights with a uniform distribution
"kaiming_uniform" or "kaiming_normal"
: Initializes weights with a Kaiming uniform or normal distribution
"xavier_uniform" or "xavier_normal"
: Initializes weights with a Xavier uniform or normal distribution
"trunc_normal"
: Initializes weights with a truncated normal distribution, clipped to trunc_clip
"dirac"
: Initializes weights with the Dirac delta function, preserving the identity of the inputs in convolution layers
"constant"
: Initializes weights to a constant value given by scale
"zeros"
: Initializes weights to all zeros
"ones"
: Initializes weights to all ones
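As an illustration of what the "orthogonal" choice produces, the sketch below uses torch.nn.init directly (the assumption being that init_ delegates to it under the hood) and verifies the defining property of an orthogonal matrix:

```python
import torch
from torch import nn

# Sketch: "orthogonal" presumably delegates to torch.nn.init.orthogonal_.
linear = nn.Linear(32, 32)
nn.init.orthogonal_(linear.weight)

# An orthogonal matrix W satisfies W @ W.T = I.
product = linear.weight @ linear.weight.T
assert torch.allclose(product, torch.eye(32), atol=1e-4)
```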
- ml.models.init.cast_init_type(s: str) → Literal['orthogonal', 'normal', 'biased_normal', 'uniform', 'kaiming_uniform', 'kaiming_normal', 'xavier_uniform', 'xavier_normal', 'trunc_normal', 'dirac', 'constant', 'zeros', 'ones'] [source]
- ml.models.init.init_(weight: Tensor, bias: Tensor | None, init: Literal['orthogonal', 'normal', 'biased_normal', 'uniform', 'kaiming_uniform', 'kaiming_normal', 'xavier_uniform', 'xavier_normal', 'trunc_normal', 'dirac', 'constant', 'zeros', 'ones'], *, mean: float = 0.0, std: float = 0.01, scale: float = 0.02, groups: int = 1, trunc_clip: tuple[float, float] = (-2.0, 2.0)) → tuple[torch.Tensor, torch.Tensor | None] [source]
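The role of cast_init_type is to narrow a plain str to the Literal type above. A minimal sketch of that pattern, with a hypothetical name and an abbreviated Literal (the real function covers all thirteen choices):

```python
from typing import Literal, get_args

# Abbreviated Literal, for illustration only.
InitTypeSketch = Literal["orthogonal", "normal", "zeros", "ones"]

def cast_init_type_sketch(s: str) -> InitTypeSketch:
    # Hypothetical sketch: validate the string against the allowed choices,
    # so type checkers can treat the result as the Literal type.
    if s not in get_args(InitTypeSketch):
        raise ValueError(f"Invalid initialization type: {s}")
    return s  # type: ignore[return-value]

init_type = cast_init_type_sketch("zeros")
```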
Initializes the weight and bias in-place, using an initialization key.
The weight and bias are from a convolution or linear layer.
- Parameters:
weight – The weight tensor
bias – The bias tensor
init – The initialization type to use
mean – The mean for normal initialization
std – The standard deviation for normal initialization
scale – The scale amount for uniform or constant initialization
groups – The number of groups, for initialization modes that use it (such as dirac)
trunc_clip – The min and max values for trunc_normal initialization
- Returns:
The initialized weight and bias (which can be discarded, since the initialization happens in-place).
- Raises:
NotImplementedError – If the initialization mode isn’t implemented
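The dispatch pattern behind init_ can be sketched as follows. This is an assumption about the structure, not the real implementation; the function name is hypothetical, and only three modes are shown:

```python
import torch
from torch import nn

def init_sketch(weight, bias, init, *, mean=0.0, std=0.01, scale=0.02):
    # Hypothetical sketch of the dispatch pattern; the real init_ supports
    # the full set of modes and keyword arguments documented above.
    if init == "normal":
        nn.init.normal_(weight, mean=mean, std=std)
    elif init == "uniform":
        nn.init.uniform_(weight, -scale, scale)
    elif init == "zeros":
        nn.init.zeros_(weight)
    else:
        raise NotImplementedError(f"Invalid initialization type: {init}")
    if bias is not None:
        nn.init.zeros_(bias)
    return weight, bias

layer = nn.Linear(16, 16)
init_sketch(layer.weight, layer.bias, "zeros")
```

Because the initialization happens in-place, the returned tensors can be ignored; returning them simply allows chaining.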