ml.models.architectures.attention

Defines self-attention modules.

You can implement a self-attention model using the built-in PyTorch module:

from torch import nn

attn = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=512,
        nhead=8,
        dim_feedforward=2048,
        dropout=0.1,
        activation='relu',
        batch_first=True,
    ),
    num_layers=6,
)

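However, during inference this module recomputes the states for all previous timesteps at every step. Instead, you can use the equivalent implementation in this file, which returns the attention state so that previous timesteps do not need to be recomputed: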

import torch

from ml.models.architectures.attention import TransformerEncoder, TransformerEncoderLayer

attn = TransformerEncoder(
    TransformerEncoderLayer(
        d_model=512,
        head_dims=64,  # Replaces nhead: 512 / 64 = 8 heads
        feedforward_factor=4,  # Replaces dim_feedforward: 4 * 512 = 2048
        dropout=0.1,
        # activation='relu',  Always ReLU
        # batch_first=True,  Always batch first
    ),
    num_layers=6,
    is_causal=True,  # Additional argument to support causal attention
    use_rotary=True,  # Additional argument to support rotary embeddings
)

x = torch.randn(1, 16, 512)  # (B, T, C)
state = None  # No cached state on the first call
x, state = attn(x, state)

This also removes the need to pass in a causal attention mask: set the is_causal argument (either in the constructor or in the forward call) and the mask is applied implicitly. By default, this uses the faster built-in PyTorch attention implementation.
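To see why returning the state helps, here is a minimal pure-PyTorch sketch (not this module's code) of streaming causal attention with a key/value cache. The function names and the (B, H, T, D) layout are assumptions for illustration. Each step processes only the newest timestep, appends its key and value to the cache, and attends over the whole cache; the result matches full causal attention computed over the entire sequence at once.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def full_causal_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    # q, k, v: (B, H, T, D). Attends over all T timesteps in one call.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)


def streaming_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    # Processes one timestep at a time, caching keys and values so that
    # earlier timesteps are never recomputed.
    k_cache, v_cache, outputs = None, None, []
    for t in range(q.shape[2]):
        k_t, v_t = k[:, :, t : t + 1], v[:, :, t : t + 1]
        k_cache = k_t if k_cache is None else torch.cat([k_cache, k_t], dim=2)
        v_cache = v_t if v_cache is None else torch.cat([v_cache, v_t], dim=2)
        # The single new query may attend to every cached key, so no mask is needed.
        outputs.append(F.scaled_dot_product_attention(q[:, :, t : t + 1], k_cache, v_cache))
    return torch.cat(outputs, dim=2)
```

The cache grows by one timestep per call, which is why an option to cap its length (such as max_kv_cache_len) is useful for long streams.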

ml.models.architectures.attention.get_attention_mask(mode: Literal['causal'], *, tsz_q: int, tsz_k: int, device: torch.device | None = None, dtype: torch.dtype | None = None) → Tensor
ml.models.architectures.attention.get_attention_mask(mode: Literal['lengths'], *, lengths: Tensor, tsz_k: int | None = None, device: torch.device | None = None, dtype: torch.dtype | None = None) → Tensor

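Returns an attention mask, either a causal mask or a mask derived from sequence lengths.

Parameters:
  • mode – The mask type, either "causal" or "lengths".

  • lengths – The lengths tensor, of shape (bsz). Only required if mode="lengths".

  • tsz_q – The number of queries. Only required if mode="causal".

  • tsz_k – The number of keys. Required if mode="causal"; optional if mode="lengths".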

  • device – The output device.

  • dtype – The output dtype.

Returns:

If mode="causal", returns a mask of shape (tsz_q, tsz_k); if mode="lengths", returns a mask of shape (bsz, tsz_k). If dtype is boolean, the mask is True where the query and key should attend to each other and False otherwise. If dtype is a floating-point type, the mask is 0 where the query and key should attend to each other and -inf otherwise, so that it can be applied by adding it to the pre-softmax attention matrix.
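The documented mask conventions can be sketched in plain PyTorch as follows. causal_mask_sketch and lengths_mask_sketch are hypothetical helpers, not this module's get_attention_mask; they assume query i lines up with key i in causal mode, and that tsz_k defaults to the longest length in lengths mode. The boolean convention (True means attend) matches the attn_mask argument of torch.nn.functional.scaled_dot_product_attention.

```python
from typing import Optional

import torch
from torch import Tensor


def causal_mask_sketch(tsz_q: int, tsz_k: int, dtype: torch.dtype = torch.bool) -> Tensor:
    # Hypothetical helper: True where query i may attend to key j (j <= i),
    # aligned to the top-left corner as torch.tril does. Shape (tsz_q, tsz_k).
    allowed = torch.ones(tsz_q, tsz_k).tril().bool()
    if dtype == torch.bool:
        return allowed
    # Float masks hold 0 where attention is allowed and -inf elsewhere, so they
    # can be added directly to the pre-softmax attention logits.
    return torch.zeros(tsz_q, tsz_k, dtype=dtype).masked_fill(~allowed, float("-inf"))


def lengths_mask_sketch(lengths: Tensor, tsz_k: Optional[int] = None) -> Tensor:
    # Hypothetical helper: True for key positions inside each sequence's length.
    # Shape (bsz, tsz_k); tsz_k defaults to the longest length (an assumption).
    if tsz_k is None:
        tsz_k = int(lengths.max().item())
    return torch.arange(tsz_k, device=lengths.device)[None, :] < lengths[:, None]
```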

class ml.models.architectures.attention.MultiheadAttention(embed_dim: int, head_dim: int, dropout: float = 0.0, bias: bool = True, kdim: int | None = None, vdim: int | None = None, gqa_factor: int = 1)

Bases: Module

Defines a streamable multihead attention layer.

This is a slightly modified version of PyTorch's built-in nn.MultiheadAttention. The main difference is that it supports streaming inference for causal attention: the caller passes in a state tuple containing the previously projected key and value tensors.

Parameters:
  • embed_dim – The input and output embedding dimension.

  • head_dim – The number of dimensions in each attention head.

  • dropout – The dropout probability, applied to the attention matrix.

  • bias – Whether to include a bias term in the projection layers.

  • kdim – The dimension of the key projection. Defaults to embed_dim.

  • vdim – The dimension of the value projection. Defaults to embed_dim.

  • gqa_factor – The grouped-query attention (GQA) factor, meaning the ratio of the number of query heads to the number of key/value heads. Higher values share each key/value head across more query heads, which shrinks the key/value projections and cache and can speed up inference.
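As a rough sketch of what gqa_factor means, assuming standard grouped-query attention (this is not necessarily how this module lays out its heads), each key/value head is shared by gqa_factor consecutive query heads. grouped_query_attention is a hypothetical helper that expands the key/value heads with repeat_interleave before calling PyTorch's attention:

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def grouped_query_attention(q: Tensor, k: Tensor, v: Tensor, gqa_factor: int) -> Tensor:
    # q: (B, H_q, T, D); k, v: (B, H_kv, T, D), with H_q == H_kv * gqa_factor.
    assert q.shape[1] == k.shape[1] * gqa_factor
    # Repeat each key/value head so it lines up with its group of query heads:
    # query heads 0..gqa_factor-1 use kv head 0, the next group uses kv head 1, etc.
    k = k.repeat_interleave(gqa_factor, dim=1)
    v = v.repeat_interleave(gqa_factor, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Only the H_kv key/value heads need to be projected and cached, so a larger gqa_factor means a proportionally smaller cache.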

Inputs:

  • query – The query tensor, of shape (B, T, C).

  • key – The key tensor, of shape (B, T, C).

  • value – The value tensor, of shape (B, T, C).

  • state – The previous key and value tensors, of shape (B * H, T', C // H), where T' is the number of previous timesteps and H is the number of attention heads. This is only supported if is_causal=True.

  • is_causal – Whether to apply a causal mask to the attention matrix. Note that the “mask” is only applied implicitly and isn’t actually instantiated as a tensor.

Outputs:

  • output – The output tensor, of shape (B, T, C), along with the key and value state for the next timestep.
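The (B * H, T', C // H) state layout can be sketched as follows. This assumes heads are split out of the channel dimension and flattened in batch-major order (row b * H + h); split_heads and append_to_state are hypothetical helpers, not part of this module.

```python
from typing import Optional

import torch
from torch import Tensor


def split_heads(x: Tensor, num_heads: int) -> Tensor:
    # (B, T, C) -> (B * H, T, C // H); row b * H + h holds head h of batch item b.
    bsz, tsz, chans = x.shape
    x = x.view(bsz, tsz, num_heads, chans // num_heads).transpose(1, 2)
    return x.reshape(bsz * num_heads, tsz, chans // num_heads)


def append_to_state(state: Optional[Tensor], new: Tensor) -> Tensor:
    # Concatenates newly projected timesteps onto the cached ones along time (T').
    return new if state is None else torch.cat([state, new], dim=1)
```

For example, caching three timesteps and then one more for a batch of 2 with C = 16 and 4 heads yields a state of shape (8, 4, 4).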

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward_matmuls(query: Tq, key: Tk, value: Tv, rotary_q: Tensor | None = None, rotary_k: Tensor | None = None) → tuple[Tq, Tk, Tv]

forward_attn(xq: Tensor, xk: Tensor, xv: Tensor, is_causal: bool = False, mask: Tensor | None = None) → Tensor

forward(query: Tensor, key: Tensor, value: Tensor, is_causal: bool = False, rotary_q: Tensor | None = None, rotary_k: Tensor | None = None, mask: Tensor | None = None) → Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_attn_matrix(xq: Tensor, xk: Tensor, is_causal: bool = False, mask: Tensor | None = None) → Tensor

Computes the attention matrix for a given query and key.

This function can be used for visualization purposes.

Parameters:
  • xq – The query embeddings, with shape (B, G, H, Tq, C)

  • xk – The key embeddings, with shape (B, G, H, Tk, C)

  • is_causal – Whether to apply a causal mask to the attention matrix. In this function, unlike in the forward pass, the mask is explicitly created if not provided.

  • mask – The attention mask, of shape (B, Tq, Tk). If None, don’t apply an attention mask.

Returns:

The attention matrix, of shape (B, G, H, Tq, Tk).
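A minimal sketch of the computation get_attn_matrix describes, assuming standard scaled dot-product attention, softmax(xq @ xk^T / sqrt(C)), and ignoring the optional mask and dropout. attn_matrix_sketch is a hypothetical name; it works on any leading dimensions, including (B, G, H), and builds the causal mask explicitly.

```python
import math

import torch
from torch import Tensor


def attn_matrix_sketch(xq: Tensor, xk: Tensor, is_causal: bool = False) -> Tensor:
    # xq: (..., Tq, C), xk: (..., Tk, C) -> attention weights of shape (..., Tq, Tk).
    logits = xq @ xk.transpose(-2, -1) / math.sqrt(xq.shape[-1])
    if is_causal:
        tsz_q, tsz_k = logits.shape[-2:]
        # Explicit (tsz_q, tsz_k) mask, True above the diagonal; broadcasts over leading dims.
        blocked = torch.ones(tsz_q, tsz_k, device=logits.device).tril().logical_not()
        logits = logits.masked_fill(blocked, float("-inf"))
    # Each row is a probability distribution over the keys.
    return logits.softmax(dim=-1)
```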

class ml.models.architectures.attention.TransformerEncoderLayer(d_model: int, head_dims: int = 64, feedforward_factor: int = 4, dropout: float = 0.1, layer_norm_eps: float = 1e-05, norm_first: bool = False, gqa_factor: int = 1, max_kv_cache_len: int | None = None)

Bases: Module

Defines a transformer encoder layer.

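This layer serves the same purpose as nn.TransformerEncoderLayer, except that it also returns the attention state for causal attention, which can be used to implement streaming inference. Note that its constructor arguments differ (for example, head_dims and feedforward_factor instead of nhead and dim_feedforward).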

Parameters:
  • d_model – The input and output embedding dimension.

  • head_dims – The number of dimensions in each attention head.

  • feedforward_factor – The factor by which the input number of dimensions is multiplied to get the feedforward hidden dimension.

  • dropout – The dropout probability, applied to the attention matrix.

  • layer_norm_eps – The layer normalization epsilon value.

  • norm_first – Whether to apply layer normalization before the attention layer.

  • gqa_factor – The grouped-query attention (GQA) factor, meaning the ratio of the number of query heads to the number of key/value heads. Higher values share each key/value head across more query heads, which shrinks the key/value projections and cache and can speed up inference.

  • max_kv_cache_len – The maximum number of previous timesteps to cache for the key and value tensors. If None, don’t clip the maximum length.
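A minimal sketch of what norm_first toggles, assuming it follows the same pre-norm/post-norm convention as nn.TransformerEncoderLayer. TinyEncoderLayer is a hypothetical class built from PyTorch's nn.MultiheadAttention; it is not this module's implementation and omits dropout, masking, and state handling.

```python
import torch
from torch import Tensor, nn


class TinyEncoderLayer(nn.Module):
    # Hypothetical layer illustrating norm_first only.
    def __init__(self, d_model: int, nhead: int, norm_first: bool) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm_first = norm_first

    def _self_attn(self, x: Tensor) -> Tensor:
        # nn.MultiheadAttention returns (output, weights); only the output is needed.
        return self.attn(x, x, x, need_weights=False)[0]

    def forward(self, x: Tensor) -> Tensor:
        if self.norm_first:
            # Pre-norm: normalize each sublayer's input; the residual path stays un-normalized.
            x = x + self._self_attn(self.norm1(x))
            x = x + self.ff(self.norm2(x))
        else:
            # Post-norm: normalize after adding each sublayer's output to the residual.
            x = self.norm1(x + self._self_attn(x))
            x = self.norm2(x + self.ff(x))
        return x
```

Pre-norm stacks are generally easier to train when deep, while post-norm matches the original Transformer layout.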

Inputs:

  • src – The input tensor, of shape (B, T, C).

  • state – The previous state tensor, if applicable.

  • is_causal – Whether to apply a causal mask to the attention matrix. Note that the “mask” is only applied implicitly and isn’t actually instantiated as a tensor.

  • rotary_q – The rotary embeddings for the query tensor, of shape (G, H, C // H). If None, don’t apply rotary embeddings.

  • rotary_k – The rotary embeddings for the key tensor, of shape (G, H, C // H). If None, don’t apply rotary embeddings.

  • mask – The attention mask, of shape (B, Tq, Tk). If None, don’t apply an attention mask.

Outputs:

  • output – The output tensor, of shape (B, T, C).

  • state – The next state tensor.

forward(src: Tensor, state: Tensor | None = None, is_causal: bool = False, rotary_q: Tensor | None = None, rotary_k: Tensor | None = None, mask: Tensor | None = None) → tuple[torch.Tensor, torch.Tensor]

class ml.models.architectures.attention.TransformerDecoderLayer(d_model: int, head_dims: int = 64, feedforward_factor: int = 4, dropout: float = 0.1, layer_norm_eps: float = 1e-05, norm_first: bool = False, gqa_factor: int = 1, memory_dims: int | None = None)

Bases: Module

Defines a transformer decoder layer.

Unlike nn.TransformerDecoderLayer, which contains both self-attention and cross-attention, this layer contains only cross-attention. To reproduce the PyTorch behavior, pair it with a self-attention layer applied first.
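One way to picture the pairing, sketched with PyTorch's built-in nn.MultiheadAttention rather than this module's layers. SelfThenCrossAttention is a hypothetical name, and layer norms, feedforward blocks, and dropout are omitted. Causal self-attention runs over src first, then src queries attend to the memory keys and values, so the output keeps the query length Tq while the memory length Tk can differ.

```python
import torch
from torch import Tensor, nn


class SelfThenCrossAttention(nn.Module):
    # Hypothetical pairing of a self-attention block with a cross-attention block.
    def __init__(self, d_model: int, nhead: int) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

    def forward(self, src: Tensor, memory: Tensor) -> Tensor:
        # src: (B, Tq, C), memory: (B, Tk, C).
        tsz = src.shape[1]
        # nn.MultiheadAttention's boolean attn_mask uses True for *blocked* positions.
        causal = torch.ones(tsz, tsz, device=src.device).triu(1).bool()
        src = src + self.self_attn(src, src, src, attn_mask=causal, need_weights=False)[0]
        # Cross-attention: queries from src, keys and values from memory.
        src = src + self.cross_attn(src, memory, memory, need_weights=False)[0]
        return src
```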

Parameters:
  • d_model – The input and output embedding dimension.

  • head_dims – The number of dimensions in each attention head.

  • feedforward_factor – The factor by which the input number of dimensions is multiplied to get the feedforward hidden dimension.

  • dropout – The dropout probability, applied to the attention matrix.

  • layer_norm_eps – The layer normalization epsilon value.

  • norm_first – Whether to apply layer normalization before the attention layer.

  • gqa_factor – The grouped-query attention (GQA) factor, meaning the ratio of the number of query heads to the number of key/value heads. Higher values share each key/value head across more query heads, which shrinks the key/value projections and cache and can speed up inference.

  • memory_dims – The number of dimensions in the memory tensor; if not provided, defaults to d_model.

Inputs:

  • src – The input tensor, of shape (B, Tq, C).

  • memory – The memory tensor, of shape (B, Tk, C).

  • state – The previous state tensor, if applicable.

  • rotary_q – The rotary embeddings for the query tensor, of shape (G, H, C // H). If None, don’t apply rotary embeddings.

  • rotary_k – The rotary embeddings for the key tensor, of shape (G, H, C // H). If None, don’t apply rotary embeddings.

  • mask – The attention mask, of shape (B, Tq, Tk). If None, don’t apply an attention mask.

Outputs:

  • output – The output tensor, of shape (B, Tq, C).

  • state – The next state tensor.

forward(src: Tensor, memory: Tensor, state: Tensor | None = None, mask: Tensor | None = None) → tuple[torch.Tensor, torch.Tensor]

class ml.models.architectures.attention.TransformerEncoder(encoder_layer: TransformerEncoderLayer, num_layers: int, norm: LayerNorm | None = None, is_causal: bool | None = None, use_rotary: bool = False, rotary_base: int = 10000)

Bases: Module

Defines a transformer encoder.

This is a drop-in replacement for nn.TransformerEncoder except that it returns the attention state for causal attention, which can be used to implement streaming inference.

This additionally supports using rotary embeddings for the key-query matrix multiplications. The rotary embedding tensors are computed at runtime and cached.

Parameters:
  • encoder_layer – The encoder layer to use.

  • num_layers – The number of encoder layers.

  • norm – The normalization layer to use. Defaults to None.

  • is_causal – Default value for is_causal in the forward method if not supplied. Controls causal versus bidirectional attention.

  • use_rotary – Default value for use_rotary in the forward method if not supplied. Controls the use of rotary embeddings in the key-query matrix multiplication.

  • rotary_base – The base value for rotary embeddings.
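The following is a sketch of standard rotary position embeddings (RoPE) with a configurable base, in the spirit of rotary_base. rotary_angles and apply_rotary are hypothetical helpers; in particular, rotating interleaved (even, odd) channel pairs is an assumption, and this module may pair channels differently. Each channel pair is rotated by an angle proportional to its position, with frequencies base^(-2i / dim), so the dot product between a rotated query and a rotated key depends only on their relative offset.

```python
import torch
from torch import Tensor


def rotary_angles(tsz: int, dim: int, base: int = 10000) -> Tensor:
    # Frequencies base^(-2i / dim), one per (even, odd) channel pair.
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    positions = torch.arange(tsz, dtype=torch.float32)
    return torch.outer(positions, inv_freq)  # (tsz, dim // 2)


def apply_rotary(x: Tensor, angles: Tensor) -> Tensor:
    # x: (..., T, D). Rotates each (even, odd) channel pair by its position's angle.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = angles.cos(), angles.sin()
    out = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return out.flatten(-2)  # Re-interleave the pairs back to (..., T, D).
```

Because the angles depend only on position and dim, they can be computed once per sequence length and reused, which fits the note above that the rotary tensors are computed at runtime and cached.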

Inputs:

  • src – The input tensor, of shape (B, T, C).

  • state – The previous state tensor, if applicable.

  • is_causal – Whether to apply a causal mask to the attention matrix. Note that the “mask” is only applied implicitly and isn’t actually instantiated as a tensor.

  • use_rotary – If set, use rotary embeddings in the key-query matrix multiplication.

  • mask – The attention mask, of shape (B, Tq, Tk). If None, don’t apply an attention mask.

Outputs:

  • output – The output tensor, of shape (B, T, C).

  • state – The next state tensor, if applicable.

forward(src: Tensor, state: Tensor | None = None, is_causal: bool | None = None, use_rotary: bool | None = None, mask: Tensor | None = None) → tuple[torch.Tensor, torch.Tensor]

class ml.models.architectures.attention.TransformerDecoder(encoder_layer: TransformerEncoderLayer, decoder_layer: TransformerDecoderLayer, num_layers: int, norm: LayerNorm | None = None, is_causal: bool | None = None, use_rotary: bool = False, rotary_base: int = 10000)

Bases: Module

Defines a transformer decoder.

Parameters:
  • encoder_layer – The encoder (self-attention) layer to use.

  • decoder_layer – The decoder (cross-attention) layer to use.

  • num_layers – The number of encoder layers.

  • norm – The normalization layer to use. Defaults to None.

  • is_causal – Default value for is_causal in the forward method if not supplied. Controls causal versus bidirectional attention.

  • use_rotary – Default value for use_rotary in the forward method if not supplied. Controls the use of rotary embeddings in the key-query matrix multiplication.

  • rotary_base – The base value for rotary embeddings.

Inputs:

  • src – The input tensor, of shape (B, Tq, C).

  • memory – The memory tensor, of shape (B, Tk, C).

  • state – The previous state tensor, if applicable.

  • is_causal – Whether to apply a causal mask to the attention matrix. Note that the “mask” is only applied implicitly and isn’t actually instantiated as a tensor.

  • use_rotary – If set, use rotary embeddings in the key-query matrix multiplication.

  • encoder_mask – The encoder attention mask, of shape (B, Tq, Tq). If None, don’t apply an attention mask to the encoder.

  • decoder_mask – The decoder attention mask, of shape (B, Tq, Tk). If None, don’t apply an attention mask to the decoder.

Outputs:

  • output – The output tensor, of shape (B, Tq, C).

  • state – The next state tensor, if applicable.

forward(src: Tensor, memory: Tensor, state: tuple[torch.Tensor, torch.Tensor] | None = None, is_causal: bool | None = None, use_rotary: bool | None = None, encoder_mask: Tensor | None = None, decoder_mask: Tensor | None = None) → tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]

ml.models.architectures.attention.nucleus_sampling(logits: Tensor, p: float, temperature: float = 1.0, dim: int = -1) → Tensor

Samples from a distribution using nucleus sampling.

This is a variant of torch.multinomial sampling that applies nucleus (top-p) filtering first. Unlike top-k sampling, which keeps a fixed number k of the most probable values, nucleus sampling keeps the smallest set of most probable values whose cumulative probability reaches p; the probability of every other value is set to zero.

Parameters:
  • logits – The input tensor, of shape (B, T, C).

  • p – The probability threshold.

  • temperature – The temperature to apply to the logits.

  • dim – The dimension to sample from. Defaults to -1.

Returns:

The sampled indices, of shape (B, T).
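A minimal sketch of top-p sampling in plain PyTorch, assuming the standard definition: always keep the most probable value, then keep adding values in descending order of probability until the cumulative probability reaches p. nucleus_sampling_sketch is a hypothetical name and is not necessarily how nucleus_sampling is implemented here.

```python
import torch
from torch import Tensor


def nucleus_sampling_sketch(logits: Tensor, p: float, temperature: float = 1.0, dim: int = -1) -> Tensor:
    # Move the sampled dimension last so every other dimension is a batch dimension.
    logits = logits.movedim(dim, -1)
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    # Probability mass strictly before each candidate; the top candidate always has 0.
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    # Drop candidates once the kept set already covers p.
    sorted_probs = sorted_probs.masked_fill(mass_before >= p, 0.0)
    # torch.multinomial accepts unnormalized non-negative weights on a 2D input.
    flat = sorted_probs.reshape(-1, sorted_probs.shape[-1])
    choice = torch.multinomial(flat, num_samples=1).reshape(*sorted_probs.shape[:-1], 1)
    # Map positions in the sorted order back to original indices.
    return sorted_idx.gather(-1, choice).squeeze(-1)
```

With p = 1.0 no candidates are dropped and this reduces to ordinary sampling from the temperature-scaled softmax; with a small p it approaches greedy (argmax) decoding.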