# mypy: disable-error-code="override"
"""Miscellaneous shared modules which can be used in various models."""
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd.function import Function, FunctionCtx
from torch.nn.common_types import _size_1_t
from ml.models.activations import ActivationType, get_activation
from ml.models.norms import NormType, get_norm_1d
class _InvertGrad(Function):
@staticmethod
def forward(ctx: FunctionCtx, input: Tensor, scale: float) -> Tensor:
ctx.scale = scale
return input
@staticmethod
def backward(ctx: FunctionCtx, grad_output: Tensor) -> tuple[Tensor, None]:
return grad_output * ctx.scale, None
def scale_grad(x: Tensor, scale: float) -> Tensor:
"""Scales the gradient of the input.
Args:
x: Input tensor.
scale: Scale factor.
Returns:
The identity of the input tensor in the forward pass, and the scaled
gradient in the backward pass.
"""
return _InvertGrad.apply(x, scale)
def invert_grad(x: Tensor) -> Tensor:
    """Inverts the gradient of the input (i.e. scales the gradient by -1)."""
return scale_grad(x, -1.0)
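# Illustrative example (hypothetical helper, not part of the original module):
# ``scale_grad`` is the identity on the forward pass, while the gradient that
# flows back through it is multiplied by ``scale`` (``invert_grad`` negates it).
def _example_scale_grad() -> None:
    x = torch.randn(3, requires_grad=True)
    scale_grad(x, 0.5).sum().backward()
    assert x.grad is not None
    assert torch.allclose(x.grad, torch.full_like(x, 0.5))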
class _SwapGrads(Function):
@staticmethod
def forward(ctx: FunctionCtx, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
return x, y
@staticmethod
def backward(ctx: FunctionCtx, grad_x: Tensor, grad_y: Tensor) -> tuple[Tensor, Tensor]:
return grad_y, grad_x
def swap_grads(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
"""Swaps the gradients of the inputs.
On the forward pass, this function returns the identity of the inputs.
On the backward pass, the gradients of X and Y are swapped.
Args:
x: First input tensor.
y: Second input tensor.
Returns:
The identity of the inputs in the forward pass, and the swapped
gradients in the backward pass.
"""
return _SwapGrads.apply(x, y)
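# Illustrative example (hypothetical helper, not part of the original module):
# the forward values are unchanged, but ``x`` receives the gradient computed
# for ``y`` and vice versa.
def _example_swap_grads() -> None:
    x = torch.randn(3, requires_grad=True)
    y = torch.randn(3, requires_grad=True)
    a, b = swap_grads(x, y)
    (2 * a + 3 * b).sum().backward()
    assert x.grad is not None and y.grad is not None
    assert torch.allclose(x.grad, torch.full_like(x, 3.0))  # Gradient of ``b``.
    assert torch.allclose(y.grad, torch.full_like(y, 2.0))  # Gradient of ``a``.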
class _CombineGrads(Function):
@staticmethod
def forward(ctx: FunctionCtx, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
return x, y
@staticmethod
def backward(ctx: FunctionCtx, grad_x: Tensor, grad_y: Tensor) -> tuple[Tensor, Tensor]:
grad = grad_x + grad_y
return grad, grad
def combine_grads(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
"""Combines the gradients of the inputs.
On the forward pass, this function returns the identity of the inputs.
On the backward pass, the gradients of X and Y are summed.
Args:
x: First input tensor.
y: Second input tensor.
Returns:
The identity of the inputs in the forward pass, and the summed
gradients in the backward pass.
"""
return _CombineGrads.apply(x, y)
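# Illustrative example (hypothetical helper, not part of the original module):
# both inputs receive the sum of the two upstream gradients (2 + 3 = 5 here).
def _example_combine_grads() -> None:
    x = torch.randn(3, requires_grad=True)
    y = torch.randn(3, requires_grad=True)
    a, b = combine_grads(x, y)
    (2 * a + 3 * b).sum().backward()
    assert x.grad is not None and y.grad is not None
    assert torch.allclose(x.grad, torch.full_like(x, 5.0))
    assert torch.allclose(y.grad, torch.full_like(y, 5.0))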
def streaming_conv_1d(
x: Tensor,
state: tuple[Tensor, int] | None,
weight: Tensor,
bias: Tensor | None,
stride: int,
padding: int,
dilation: int,
groups: int,
) -> tuple[Tensor, tuple[Tensor, int]]:
"""Applies a streaming convolution.
Args:
x: The input to the convolution.
state: The state of the convolution, which is the part of the previous
input which is left over for computing the current convolution,
along with an integer tracker for the number of samples to clip
from the current input.
weight: The convolution weights.
bias: The convolution bias.
stride: The convolution stride.
padding: The convolution padding.
dilation: The convolution dilation.
groups: The convolution groups.
Returns:
The output of the convolution, plus the new state tracker.
"""
if state is None:
pre_t = 0
if padding > 0:
x = torch.cat((x.new_zeros(x.shape[:-1] + (padding,)), x), dim=-1)
else:
pre_x, pre_t = state
x = torch.cat((pre_x, x), dim=-1)
if pre_t > 0:
pre_t, x = pre_t - x.shape[-1], x[..., pre_t:]
(bsz, _, tsz), (chsz_out, _, ksize) = x.shape, weight.shape
min_tsz = 1 + (ksize - 1) * dilation
if tsz < min_tsz:
return x.new_zeros(bsz, chsz_out, 0), (x, pre_t)
y = F.conv1d(x, weight, bias, stride, 0, dilation, groups)
t = stride * y.shape[-1]
return y, (x[:, :, t:], max(0, t - tsz))
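# Illustrative example (hypothetical helper, not part of the original module):
# with stride 1 and a causal padding of ``(kernel_size - 1) * dilation``,
# running the convolution chunk by chunk and concatenating the outputs should
# match a single left-padded convolution over the full sequence.
def _example_streaming_conv_1d() -> None:
    ksize = 5
    x = torch.randn(2, 3, 16)
    weight, bias = torch.randn(4, 3, ksize), torch.randn(4)
    full = F.conv1d(F.pad(x, (ksize - 1, 0)), weight, bias)
    state: tuple[Tensor, int] | None = None
    outs: list[Tensor] = []
    for chunk in x.split(4, dim=-1):
        y, state = streaming_conv_1d(chunk, state, weight, bias, 1, ksize - 1, 1, 1)
        outs.append(y)
    assert torch.allclose(torch.cat(outs, dim=-1), full, atol=1e-6)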
def streaming_conv_transpose_1d(
x: Tensor,
state: tuple[Tensor, int] | None,
weight: Tensor,
bias: Tensor | None,
stride: int,
dilation: int,
groups: int,
) -> tuple[Tensor, tuple[Tensor, int]]:
"""Applies a streaming transposed convolution.
Args:
x: The input to the convolution.
state: The state of the convolution, which is the part of the previous
input which is left over for computing the current convolution,
along with an integer tracker for the number of samples to clip
from the current input.
weight: The convolution weights.
bias: The convolution bias.
stride: The convolution stride.
dilation: The convolution dilation.
groups: The convolution groups.
Returns:
The output of the convolution, plus the new state tracker.
"""
y = F.conv_transpose1d(x, weight, bias, stride, 0, 0, groups, dilation)
bsz, chsz_out, tsz = y.shape
t = stride * x.shape[-1]
if state is not None:
post_y, post_t = state
if post_t > 0:
init_y = y.new_zeros(bsz, chsz_out, post_t)
if bias is not None:
init_y += bias[..., None]
y = torch.cat([init_y, y], dim=-1)
n = min(post_y.shape[-1], y.shape[-1])
init_y = post_y[..., :n] + y[..., :n]
if bias is not None:
init_y -= bias[..., None]
y = torch.cat((init_y, post_y[..., n:], y[..., n:]), dim=-1)
return y[..., :t], (y[..., t:], max(0, t - tsz))
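# Illustrative example (hypothetical helper, not part of the original module):
# the streamed transposed convolution emits ``stride * chunk_len`` samples per
# chunk, so the concatenated outputs form a prefix of the full
# ``conv_transpose1d`` result, with the overlap-add tail carried in the state.
def _example_streaming_conv_transpose_1d() -> None:
    x = torch.randn(2, 3, 12)
    weight, bias, stride = torch.randn(3, 5, 4), torch.randn(5), 2
    full = F.conv_transpose1d(x, weight, bias, stride)
    state: tuple[Tensor, int] | None = None
    outs: list[Tensor] = []
    for chunk in x.split(3, dim=-1):
        y, state = streaming_conv_transpose_1d(chunk, state, weight, bias, stride, 1, 1)
        outs.append(y)
    streamed = torch.cat(outs, dim=-1)
    assert torch.allclose(streamed, full[..., : streamed.shape[-1]], atol=1e-6)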
class StreamingConv1d(nn.Conv1d):
"""Defines a streaming 1D convolution layer.
This is analogous to streaming RNNs, where a state is maintained going
forward in time. For convolutions, the state is simply the part of the
previous input which is left over for computing the current convolution,
along with an integer tracker for the number of samples to clip from the
current input.
Note that this is a drop-in replacement for ``nn.Conv1d`` so far as the
weights and biases go, but the forward pass takes an additional state
argument and returns an additional state output.
"""
padding: tuple[int, ...]
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: _size_1_t = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
) -> None:
super().__init__(
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
in_channels=in_channels,
out_channels=out_channels,
)
assert isinstance(self.padding, tuple) and len(self.padding) == 1 and isinstance(self.padding[0], int)
    def forward(
self,
x: Tensor,
state: tuple[Tensor, int] | None = None,
) -> tuple[Tensor, tuple[Tensor, int]]:
weight, bias = self.weight, self.bias
stride, padding, dilation, groups = self.stride[0], self.padding[0], self.dilation[0], self.groups
return streaming_conv_1d(x, state, weight, bias, stride, padding, dilation, groups)
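# Illustrative example (hypothetical helper, not part of the original module):
# typical use of ``StreamingConv1d``, threading the state through successive
# chunks of a ``(batch, channels, time)`` stream.
def _example_streaming_conv1d_module() -> None:
    conv = StreamingConv1d(8, 16, kernel_size=3, padding=2)
    state: tuple[Tensor, int] | None = None
    for _ in range(4):
        chunk = torch.randn(1, 8, 10)
        y, state = conv(chunk, state)  # ``y`` has shape (1, 16, 10) here.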
class StreamingConvTranspose1d(nn.ConvTranspose1d):
"""Defines a streaming 1D transposed convolution layer.
This is the inverse of ``StreamingConv1d``, with the caveat that padding
is not supported.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
) -> None:
super().__init__(
kernel_size=kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
in_channels=in_channels,
out_channels=out_channels,
)
    def forward(
self,
x: Tensor,
state: tuple[Tensor, int] | None = None,
) -> tuple[Tensor, tuple[Tensor, int]]:
weight, bias = self.weight, self.bias
stride, dilation, groups = self.stride[0], self.dilation[0], self.groups
return streaming_conv_transpose_1d(x, state, weight, bias, stride, dilation, groups)
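# Illustrative example (hypothetical helper, not part of the original module):
# ``StreamingConvTranspose1d`` is used the same way, emitting
# ``stride * chunk_len`` upsampled samples per chunk.
def _example_streaming_conv_transpose1d_module() -> None:
    convt = StreamingConvTranspose1d(16, 8, kernel_size=4, stride=2)
    state: tuple[Tensor, int] | None = None
    for _ in range(4):
        chunk = torch.randn(1, 16, 10)
        y, state = convt(chunk, state)  # ``y`` has shape (1, 8, 20) here.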
def streaming_add(
a: Tensor,
b: Tensor,
state: tuple[Tensor, Tensor] | None = None,
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
"""Performs streaming addition of two tensors.
Args:
a: The first tensor, with shape ``(B, C, T)``
b: The second tensor, with shape ``(B, C, T)``
state: The state of the addition, which is the leftover part from the
previous addition.
Returns:
The sum of the two tensors, plus the new state.
"""
if state is not None:
prev_a, prev_b = state
a = torch.cat((prev_a, a), dim=-1)
b = torch.cat((prev_b, b), dim=-1)
n = min(a.shape[-1], b.shape[-1])
y = a[..., :n] + b[..., :n]
return y, (a[..., n:], b[..., n:])
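# Illustrative example (hypothetical helper, not part of the original module):
# ``streaming_add`` sums two streams that may arrive in differently-sized
# chunks, buffering whichever stream is ahead.
def _example_streaming_add() -> None:
    a, b = torch.randn(1, 4, 7), torch.randn(1, 4, 5)
    y, state = streaming_add(a, b)
    assert y.shape[-1] == 5  # Only the overlapping region is emitted.
    y2, state = streaming_add(a.new_zeros(1, 4, 0), b.new_zeros(1, 4, 2), state)
    assert y2.shape[-1] == 2  # The buffered tail of ``a`` is now consumed.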
def streamable_cbr(
in_channels: int,
out_channels: int,
kernel_size: int,
dilation: int,
norm: NormType = "batch_affine",
act: ActivationType = "gelu",
bias: bool = False,
groups: int = 1,
group_norm_groups: int | None = None,
) -> nn.Module:
"""Defines a streamable convolution-batchnorm-ReLU module.
This is a convenience function for defining a streamable convolution
module. We pad the left side of the input so that each timestep only
depends on previous timesteps, and not future timesteps, allowing us to
compute the output of the convolution without having to wait for future
timesteps.
Args:
in_channels: The number of input channels.
out_channels: The number of output channels.
kernel_size: The kernel size.
dilation: The dilation.
norm: The normalization type.
act: The activation type.
bias: Whether to use a bias. This should be turned off if the
convolution is followed by a batch normalization layer.
groups: The number of groups for convolution.
group_norm_groups: The number of groups for group normalization.
Returns:
The streamable convolution module.
"""
pad_amount = (kernel_size - 1) * dilation
return nn.Sequential(
nn.ConstantPad1d((pad_amount, 0), 0.0) if pad_amount > 0 else nn.Identity(),
nn.Conv1d(in_channels, out_channels, kernel_size, groups=groups, dilation=dilation, bias=bias),
get_norm_1d(norm, dim=out_channels, groups=group_norm_groups),
get_activation(act, inplace=True),
)
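# Illustrative example (hypothetical helper, not part of the original module):
# because the block only pads on the left, the output has the same length as
# the input and each timestep depends only on current and past inputs.
def _example_streamable_cbr() -> None:
    block = streamable_cbr(8, 16, kernel_size=3, dilation=2)
    x = torch.randn(2, 8, 50)
    assert block(x).shape == (2, 16, 50)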
class residual(nn.Module):  # noqa: N801
"""Defines a residual connection module.
The child module should take a single tensor as input and return a single
tensor as output, with the same shape as the input.
Parameters:
module: The child module.
Inputs:
x: The input tensor, with shape ``(*)``
Outputs:
The output tensor, with shape ``(*)``
"""
def __init__(self, module: nn.Module) -> None:
super().__init__()
self.module = module
    def forward(self, x: Tensor) -> Tensor:
return x + self.module(x)
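# Illustrative example (hypothetical helper, not part of the original module):
# wrapping a module in ``residual`` adds a skip connection around it.
def _example_residual() -> None:
    block = residual(nn.Linear(16, 16))
    x = torch.randn(4, 16)
    assert torch.allclose(block(x), x + block.module(x))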
class gated_residual(nn.Module):  # noqa: N801
"""Defines a gated residual connection module.
The child module and gate should take a single tensor as input and return
a single tensor as output, with the same shape as the input.
Parameters:
module: The child module.
gate: The gating module.
Inputs:
x: The input tensor, with shape ``(*)``
Outputs:
The output tensor, with shape ``(*)``
"""
def __init__(self, module: nn.Module, gate: nn.Module) -> None:
super().__init__()
self.module = module
self.gate = gate
    def forward(self, x: Tensor) -> Tensor:
y, g = self.module(x), self.gate(x)
return x * g + y * (1 - g)
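# Illustrative example (hypothetical helper, not part of the original module):
# the gate interpolates between the input (where the gate is 1) and the module
# output (where the gate is 0).
def _example_gated_residual() -> None:
    module = nn.Linear(16, 16)
    gate = nn.Sequential(nn.Linear(16, 16), nn.Sigmoid())
    block = gated_residual(module, gate)
    x = torch.randn(4, 16)
    g = block.gate(x)
    assert torch.allclose(block(x), x * g + block.module(x) * (1 - g))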
def drop_path(x: Tensor, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True) -> Tensor:
    """Functional form of ``DropPath`` (stochastic depth).

    During training, each sample in the batch is dropped (zeroed) with
    probability ``drop_prob``; if ``scale_by_keep`` is set, the surviving
    samples are rescaled by ``1 / (1 - drop_prob)``. Outside of training this
    is the identity function.
    """
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample.
This simulates stochastic depth for residual networks by randomly dropping
out the residual tensor.
Parameters:
        drop_prob: The probability of dropping the path.
scale_by_keep: If set, scale the non-dropped path to compensate for
the dropped path.
Inputs:
x: The input tensor, with shape ``(*)``. This should be the residual
connection.
Outputs:
The identity function, if not training, or the stochastically dropped
input tensor, with shape ``(*)``.
"""
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None:
super().__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
    def forward(self, x: Tensor) -> Tensor:
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
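# Illustrative example (hypothetical helper, not part of the original module):
# in training mode, roughly half of the samples have their residual branch
# zeroed (and the rest are scaled by 1 / 0.5); in eval mode it is the identity.
def _example_drop_path() -> None:
    layer = DropPath(drop_prob=0.5)
    x = torch.randn(8, 16)
    layer.train()
    _ = layer(x)
    layer.eval()
    assert torch.equal(layer(x), x)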