ml.models.modules

Miscellaneous shared modules which can be used in various models.

ml.models.modules.scale_grad(x: Tensor, scale: float) -> Tensor

Scales the gradient of the input.

Parameters:
  • x – Input tensor.

  • scale – Scale factor.

Returns:

The input tensor, unchanged, in the forward pass; in the backward pass, the gradient multiplied by scale.
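
One common way to get this behavior in PyTorch is a custom torch.autograd.Function whose forward pass is the identity and whose backward pass multiplies the incoming gradient by the scale factor. The following is a minimal sketch under that assumption, not the library's actual implementation; the names _ScaleGradSketch and scale_grad_sketch are hypothetical.

```python
import torch
from torch import Tensor


class _ScaleGradSketch(torch.autograd.Function):
    """Hypothetical helper: identity forward, scaled gradient backward."""

    @staticmethod
    def forward(ctx, x: Tensor, scale: float) -> Tensor:
        ctx.scale = scale
        return x.view_as(x)  # same values as the input

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # One return value per forward input; `scale` gets no gradient.
        return grad_output * ctx.scale, None


def scale_grad_sketch(x: Tensor, scale: float) -> Tensor:
    return _ScaleGradSketch.apply(x, scale)


x = torch.ones(3, requires_grad=True)
y = scale_grad_sketch(x, 0.5)
y.sum().backward()  # x.grad now holds the upstream gradient times 0.5
```

The forward values are untouched, so the function can be placed anywhere in a model without changing its outputs; only the gradient reaching earlier layers is affected.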

ml.models.modules.invert_grad(x: Tensor) -> Tensor

ml.models.modules.swap_grads(x: Tensor, y: Tensor) -> tuple[torch.Tensor, torch.Tensor]

Swaps the gradients of the inputs.

On the forward pass, this function returns the inputs unchanged. On the backward pass, the gradients of x and y are swapped.

Parameters:
  • x – First input tensor.

  • y – Second input tensor.

Returns:

The identity of the inputs in the forward pass, and the swapped gradients in the backward pass.
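
As an illustration of the swap, the sketch below uses a custom torch.autograd.Function that returns both inputs unchanged and exchanges the two upstream gradients in the backward pass. It is an assumption about how such a function could be written, not the library's code; _SwapGradsSketch and swap_grads_sketch are hypothetical names, and the sketch assumes x and y have the same shape.

```python
import torch
from torch import Tensor


class _SwapGradsSketch(torch.autograd.Function):
    """Hypothetical helper: identity forward, swapped gradients backward."""

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor):
        return x.view_as(x), y.view_as(y)

    @staticmethod
    def backward(ctx, grad_x: Tensor, grad_y: Tensor):
        # x receives the gradient that flowed into y's output, and vice versa.
        return grad_y, grad_x


def swap_grads_sketch(x: Tensor, y: Tensor):
    return _SwapGradsSketch.apply(x, y)


x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
a, b = swap_grads_sketch(x, y)
(2.0 * a.sum() + 5.0 * b.sum()).backward()
# x.grad is filled with 5.0 (the gradient of b); y.grad is filled with 2.0.
```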

ml.models.modules.combine_grads(x: Tensor, y: Tensor) -> tuple[torch.Tensor, torch.Tensor]

Combines the gradients of the inputs.

On the forward pass, this function returns the inputs unchanged. On the backward pass, the gradients of x and y are summed.

Parameters:
  • x – First input tensor.

  • y – Second input tensor.

Returns:

The identity of the inputs in the forward pass, and the summed gradients in the backward pass.
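
The sketch below shows one possible reading of "summed": both inputs receive the sum of the two upstream gradients. That interpretation is an assumption, as are the hypothetical names _CombineGradsSketch and combine_grads_sketch; the library's actual behavior may differ in detail. The inputs are assumed to share a shape.

```python
import torch
from torch import Tensor


class _CombineGradsSketch(torch.autograd.Function):
    """Hypothetical helper: identity forward, summed gradients backward."""

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor):
        return x.view_as(x), y.view_as(y)

    @staticmethod
    def backward(ctx, grad_x: Tensor, grad_y: Tensor):
        # Assumption: each input receives the sum of both upstream gradients.
        total = grad_x + grad_y
        return total, total


def combine_grads_sketch(x: Tensor, y: Tensor):
    return _CombineGradsSketch.apply(x, y)


x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
a, b = combine_grads_sketch(x, y)
(2.0 * a.sum() + 5.0 * b.sum()).backward()
# Under the assumption above, x.grad and y.grad are both filled with 7.0.
```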

ml.models.modules.streaming_conv_1d(x: Tensor, state: tuple[torch.Tensor, int] | None, weight: Tensor, bias: Tensor | None, stride: int, padding: int, dilation: int, groups: int) -> tuple[torch.Tensor, tuple[torch.Tensor, int]]

Applies a streaming convolution.

Parameters:
  • x – The input to the convolution.

  • state – The state of the convolution: the leftover part of the previous input that is still needed to compute the current convolution, together with an integer counting the samples to clip from the current input.

  • weight – The convolution weights.

  • bias – The convolution bias.

  • stride – The convolution stride.

  • padding – The convolution padding.

  • dilation – The convolution dilation.

  • groups – The convolution groups.

Returns:

The output of the convolution, plus the new state.
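
To make the role of the state concrete, here is a simplified sketch of a streaming convolution built on torch.nn.functional.conv1d. It is written under several assumptions (no padding, no bias, groups of 1) and is not the library's implementation; streaming_conv_sketch is a hypothetical name. The state holds the unconsumed tail of the buffered input plus a count of samples to skip from the next chunk, which becomes non-zero only when the stride is larger than the kernel span.

```python
import torch
import torch.nn.functional as F


def streaming_conv_sketch(x, state, weight, stride=1, dilation=1):
    """Hypothetical streaming conv: returns (output, (leftover, skip))."""
    leftover, skip = state if state is not None else (x[..., :0], 0)
    # Clip samples that an earlier call already stepped over.
    clip = min(skip, x.shape[-1])
    x, skip = x[..., clip:], skip - clip
    buf = torch.cat([leftover, x], dim=-1)
    field = dilation * (weight.shape[-1] - 1) + 1  # kernel span in samples
    length = buf.shape[-1]
    if length < field:
        # Not enough samples for a single output yet; keep buffering.
        empty = buf.new_zeros(buf.shape[0], weight.shape[0], 0)
        return empty, (buf, skip)
    y = F.conv1d(buf, weight, stride=stride, dilation=dilation)
    consumed = y.shape[-1] * stride  # where the next window starts
    return y, (buf[..., consumed:], skip + max(0, consumed - length))


torch.manual_seed(0)
x = torch.randn(1, 2, 20)
weight = torch.randn(3, 2, 3)
full = F.conv1d(x, weight, stride=2)

state, outputs = None, []
for chunk in x.split([7, 5, 8], dim=-1):
    y, state = streaming_conv_sketch(chunk, state, weight, stride=2)
    outputs.append(y)
streamed = torch.cat(outputs, dim=-1)
```

Feeding the input in chunks and concatenating the outputs should reproduce the result of a single convolution over the whole sequence, which is the property a streaming convolution is meant to preserve.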

ml.models.modules.streaming_conv_transpose_1d(x: Tensor, state: tuple[torch.Tensor, int] | None, weight: Tensor, bias: Tensor | None, stride: int, dilation: int, groups: int) -> tuple[torch.Tensor, tuple[torch.Tensor, int]]

Applies a streaming transposed convolution.

Parameters:
  • x – The input to the convolution.

  • state – The state of the convolution: the leftover part of the previous input that is still needed to compute the current convolution, together with an integer counting the samples to clip from the current input.

  • weight – The convolution weights.

  • bias – The convolution bias.

  • stride – The convolution stride.

  • dilation – The convolution dilation.

  • groups – The convolution groups.

Returns:

The output of the convolution, plus the new state.
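
A transposed convolution can be streamed with an overlap-add scheme: each chunk's output overlaps the next chunk's output by the kernel size minus the stride, so that overlapping tail is carried forward. The sketch below illustrates this idea only. It assumes a dilation of 1, no bias, groups of 1, and a kernel size no smaller than the stride, and it carries just the output tail as state, which may differ from the state layout the library uses. streaming_conv_transpose_sketch is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def streaming_conv_transpose_sketch(x, tail, weight, stride=1):
    """Hypothetical overlap-add streaming transposed conv."""
    y = F.conv_transpose1d(x, weight, stride=stride)
    if tail is not None:
        n = tail.shape[-1]
        # Add the previous chunk's overlapping tail onto the head of y.
        y = torch.cat([y[..., :n] + tail, y[..., n:]], dim=-1)
    emit = x.shape[-1] * stride  # samples that later chunks cannot change
    return y[..., :emit], y[..., emit:]


torch.manual_seed(0)
x = torch.randn(1, 2, 12)
weight = torch.randn(2, 3, 4)  # (in_channels, out_channels, kernel_size)
full = F.conv_transpose1d(x, weight, stride=2)

tail, outputs = None, []
for chunk in x.split([5, 7], dim=-1):
    y, tail = streaming_conv_transpose_sketch(chunk, tail, weight, stride=2)
    outputs.append(y)
# Flush the final tail once the stream has ended.
streamed = torch.cat(outputs + [tail], dim=-1)
```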

class ml.models.modules.StreamingConv1d(in_channels: int, out_channels: int, kernel_size: int | Tuple[int], stride: int | Tuple[int] = 1, padding: int | Tuple[int] = 0, dilation: int | Tuple[int] = 1, groups: int = 1, bias: bool = True)

Bases: Conv1d

Defines a streaming 1D convolution layer.

This is analogous to streaming RNNs, where a state is maintained going forward in time. For convolutions, the state is simply the part of the previous input which is left over for computing the current convolution, along with an integer tracker for the number of samples to clip from the current input.

Note that this is a drop-in replacement for nn.Conv1d so far as the weights and biases go, but the forward pass takes an additional state argument and returns an additional state output.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

padding: tuple[int, ...]

forward(x: Tensor, state: tuple[torch.Tensor, int] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, int]]

Applies the streaming convolution to x using the state from the previous call (None by default), and returns the output together with the new state.

Note

Call the module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() directly silently ignores them.

class ml.models.modules.StreamingConvTranspose1d(in_channels: int, out_channels: int, kernel_size: int | Tuple[int], stride: int | Tuple[int] = 1, dilation: int | Tuple[int] = 1, groups: int = 1, bias: bool = True)

Bases: ConvTranspose1d

Defines a streaming 1D transposed convolution layer.

This is the inverse of StreamingConv1d, with the caveat that padding is not supported.

forward(x: Tensor, state: tuple[torch.Tensor, int] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, int]]

Applies the streaming transposed convolution to x using the state from the previous call (None by default), and returns the output together with the new state.

Note

Call the module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() directly silently ignores them.

ml.models.modules.streaming_add(a: Tensor, b: Tensor, state: tuple[torch.Tensor, torch.Tensor] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]

Performs streaming addition of two tensors.

Parameters:
  • a – The first tensor, with shape (B, C, T).

  • b – The second tensor, with shape (B, C, T).

  • state – The state of the addition, which is the leftover part from the previous addition.

Returns:

The sum of the two tensors, plus the new state.
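
Streaming addition matters when the two streams deliver different numbers of timesteps per call, for example when one branch has buffered more input than the other. The sketch below shows one way this could work, assuming the state holds the unpaired leftover from each stream; it is not the library's implementation, and streaming_add_sketch is a hypothetical name.

```python
import torch


def streaming_add_sketch(a, b, state=None):
    """Hypothetical streaming add over the last (time) dimension."""
    if state is not None:
        # Prepend whatever was left unpaired after the previous call.
        a = torch.cat([state[0], a], dim=-1)
        b = torch.cat([state[1], b], dim=-1)
    n = min(a.shape[-1], b.shape[-1])  # timesteps available in both streams
    return a[..., :n] + b[..., :n], (a[..., n:], b[..., n:])


torch.manual_seed(0)
a_full = torch.randn(1, 2, 10)
b_full = torch.randn(1, 2, 10)

state, outputs = None, []
a_chunks = a_full.split([3, 5, 2], dim=-1)
b_chunks = b_full.split([4, 2, 4], dim=-1)
for a, b in zip(a_chunks, b_chunks):
    y, state = streaming_add_sketch(a, b, state)
    outputs.append(y)
streamed = torch.cat(outputs, dim=-1)
```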

ml.models.modules.streamable_cbr(in_channels: int, out_channels: int, kernel_size: int, dilation: int, norm: Literal['no_norm', 'batch', 'batch_affine', 'instance', 'instance_affine', 'group', 'group_affine', 'layer', 'layer_affine'] = 'batch_affine', act: Literal['no_act', 'relu', 'relu6', 'relu2', 'clamp6', 'leaky_relu', 'elu', 'celu', 'selu', 'gelu', 'gelu_fast', 'sigmoid', 'log_sigmoid', 'hard_sigomid', 'tanh', 'softsign', 'softplus', 'silu', 'mish', 'swish', 'hard_swish', 'soft_shrink', 'hard_shrink', 'tanh_shrink', 'soft_sign', 'relu_squared', 'laplace'] = 'gelu', bias: bool = False, groups: int = 1, group_norm_groups: int | None = None) -> Module

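Defines a streamable convolution-normalization-activation module (the "CBR" pattern, named after convolution-batchnorm-ReLU). The normalization and activation are selected with the norm and act arguments.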

This is a convenience function for defining a streamable convolution module. We pad the left side of the input so that each timestep only depends on previous timesteps, and not future timesteps, allowing us to compute the output of the convolution without having to wait for future timesteps.

Parameters:
  • in_channels – The number of input channels.

  • out_channels – The number of output channels.

  • kernel_size – The kernel size.

  • dilation – The dilation.

  • norm – The normalization type.

  • act – The activation type.

  • bias – Whether to use a bias. This should be turned off if the convolution is followed by a batch normalization layer.

  • groups – The number of groups for convolution.

  • group_norm_groups – The number of groups for group normalization.

Returns:

The streamable convolution module.
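
The left-padding idea described above can be checked directly with plain PyTorch. In the sketch below, the input is padded on the left by dilation * (kernel_size - 1) samples, so each output timestep depends only on the current and earlier inputs. This is an illustration of causal padding in general, not the module returned by streamable_cbr; the helper name causal_conv is hypothetical, and normalization and activation are left out.

```python
import torch
import torch.nn.functional as F

kernel_size, dilation = 3, 2
pad = dilation * (kernel_size - 1)  # left padding that keeps the conv causal

torch.manual_seed(0)
x = torch.randn(1, 4, 16)
weight = torch.randn(8, 4, kernel_size)


def causal_conv(inp):
    # Pad only the left side of the time axis, then convolve.
    return F.conv1d(F.pad(inp, (pad, 0)), weight, dilation=dilation)


y = causal_conv(x)

# Perturb only the "future" timesteps (t >= 10) and recompute.
x_future = x.clone()
x_future[..., 10:] += 1.0
y_future = causal_conv(x_future)
```

The outputs before timestep 10 are unaffected by the perturbation, and the output has the same number of timesteps as the input.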

class ml.models.modules.residual(module: Module)

Bases: Module

Defines a residual connection module.

The child module should take a single tensor as input and return a single tensor as output, with the same shape as the input.

Parameters:

module – The child module.

Inputs:

x: The input tensor, with shape (*)

Outputs:

The output tensor, with shape (*)

forward(x: Tensor) -> Tensor

Applies the child module to x with the residual connection and returns a tensor of the same shape.

Note

Call the module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() directly silently ignores them.
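
For reference, a residual wrapper in its most common form adds the child module's output to its input, which is why the child must preserve the input shape. The sketch below assumes that form; ResidualSketch is a hypothetical name and is not the library's class.

```python
import torch
from torch import Tensor, nn


class ResidualSketch(nn.Module):
    """Hypothetical residual wrapper: returns x + module(x)."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: Tensor) -> Tensor:
        # The child output must have the same shape as x for the sum to work.
        return x + self.module(x)


torch.manual_seed(0)
layer = ResidualSketch(nn.Linear(4, 4))
x = torch.randn(2, 4)
out = layer(x)
```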

class ml.models.modules.gated_residual(module: Module, gate: Module)

Bases: Module

Defines a gated residual connection module.

The child module and gate should take a single tensor as input and return a single tensor as output, with the same shape as the input.

Parameters:
  • module – The child module.

  • gate – The gating module.

Inputs:

x: The input tensor, with shape (*)

Outputs:

The output tensor, with shape (*)

forward(x: Tensor) -> Tensor

Applies the child module to x with the gated residual connection and returns a tensor of the same shape.

Note

Call the module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() directly silently ignores them.

ml.models.modules.drop_path(x: Tensor, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True) -> Tensor

class ml.models.modules.DropPath(drop_prob: float = 0.0, scale_by_keep: bool = True)

Bases: Module

Drop paths (Stochastic Depth) per sample.

This simulates stochastic depth for residual networks by randomly dropping out the residual tensor.

Parameters:
  • drop_prob – The probability of dropping the path for each sample.

  • scale_by_keep – If set, scale the non-dropped path to compensate for the dropped path.

Inputs:

x: The input tensor, with shape (*). This should be the residual connection.

Outputs:

The input tensor unchanged if not training; otherwise the stochastically dropped input tensor, with shape (*).

forward(x: Tensor) -> Tensor

Returns x unchanged when the module is not training; otherwise returns x with whole samples stochastically dropped.

Note

Call the module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() directly silently ignores them.

extra_repr() -> str

Sets the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
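
The per-sample dropping that DropPath describes can be sketched as follows. This follows the widely used stochastic depth formulation: one Bernoulli draw per sample, broadcast over all remaining dimensions, with optional rescaling by the keep probability. It is an assumption that the library follows the same recipe; drop_path_sketch is a hypothetical name.

```python
import torch
from torch import Tensor


def drop_path_sketch(
    x: Tensor,
    drop_prob: float = 0.0,
    training: bool = False,
    scale_by_keep: bool = True,
) -> Tensor:
    """Hypothetical stochastic depth: zeroes whole samples at random."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One random draw per sample, broadcast over the remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)  # keeps the expected value of the output equal to x
    return x * mask


torch.manual_seed(0)
x = torch.ones(8, 3, 5)
y_eval = drop_path_sketch(x, drop_prob=0.5, training=False)
y_train = drop_path_sketch(x, drop_prob=0.5, training=True)
```

In evaluation mode the input passes through untouched. In training mode each sample is either zeroed entirely or, with scale_by_keep set, scaled by 1 / keep_prob.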