ml.utils.attention
Defines some augmentations for attention functions.
- class ml.utils.attention.NextTokenDiscriminator(emb_dim: int, max_tsz: int)
Bases: Module
Defines a module for training a discriminator on next token predictions.
Consider GAN-style training of an autoregressive transformer model. For an input sequence with shape (T, C), the generator model outputs a next-token prediction for each timestep, giving a tensor with shape (T, C). The discriminator model then conditions on the ground-truth tokens and the predicted next tokens to produce a discriminator score. The trick is that, when scoring the predicted token distribution for a given timestep, the discriminator should only be able to see the ground-truth tokens before that timestep, not the ground-truth token at that timestep.
This module takes the input tensors described above, prepends an “initial token” to the first (ground-truth) tensor, and concatenates the two tensors to get the input tensor for training the discriminator model. It also returns the attention mask to use when training that model.
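As a rough illustration of this masking trick, the sketch below builds a (2T, 2T) boolean attention mask under our own assumptions, which may not match the exact layout this module returns: the first T slots hold the ground-truth tokens shifted right by one (slot 0 holds the initial token), the last T slots hold the predictions, and True marks a blocked query/key pair, which is the boolean convention of the attn_mask / mask arguments of torch.nn.MultiheadAttention and torch.nn.TransformerEncoder. The function name next_token_discriminator_mask is hypothetical.

```python
import torch


def next_token_discriminator_mask(tsz: int) -> torch.Tensor:
    """Returns a (2T, 2T) bool mask in which True marks a blocked query/key pair.

    Hypothetical layout (an assumption, not the library's documented one):
    slots 0..T-1 hold ground-truth tokens shifted right by one (slot 0 = initial token),
    slots T..2T-1 hold the predicted next tokens.
    """
    # Causal block: slot i may attend to slots 0..i only (True strictly above the diagonal).
    causal = ~torch.tril(torch.ones(tsz, tsz, dtype=torch.bool))
    # Prediction-to-prediction block: each prediction sees only itself.
    only_self = ~torch.eye(tsz, dtype=torch.bool)
    # Ground-truth slots never attend to predictions.
    blocked = torch.ones(tsz, tsz, dtype=torch.bool)
    top = torch.cat((causal, blocked), dim=1)
    # Prediction t sees ground-truth slots 0..t, i.e. the initial token and tokens 0..t-1.
    bottom = torch.cat((causal, only_self), dim=1)
    return torch.cat((top, bottom), dim=0)


mask = next_token_discriminator_mask(3)
# Pattern for T = 3 (1 = blocked); rows are queries, columns are keys:
# [[0, 1, 1, 1, 1, 1],
#  [0, 0, 1, 1, 1, 1],
#  [0, 0, 0, 1, 1, 1],
#  [0, 1, 1, 0, 1, 1],
#  [0, 0, 1, 1, 0, 1],
#  [0, 0, 0, 1, 1, 0]]
print(mask.int().tolist())
```

Because ground-truth slot t holds token t - 1, the prediction for token t sees exactly the initial token and tokens 0..t-1. Blocking prediction-to-prediction attention matters too: the prediction for token t + 1 was computed from ground-truth token t, so letting the prediction for token t attend to it would leak that token. Every row keeps at least one unmasked entry, so the attention softmax never sees a fully masked row. Note that torch.nn.functional.scaled_dot_product_attention uses the opposite boolean convention (True means "may attend"), so a mask like this would have to be inverted before passing it there.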
This module can also be used for other applications that define a conditional distribution over next-token predictions, such as reinforcement learning.
- Parameters:
emb_dim – The attention embedding dimension.
max_tsz – The maximum number of input tokens.
- causal_mask: Tensor
- curr_mask: Tensor
- forward(prev_embs: Tensor, curr_embs: Tensor) → tuple[torch.Tensor, torch.Tensor]
Combines the embeddings to get the transformer inputs.
Note that prev_embs and curr_embs should correspond to the same token at each timestep; the only difference is that prev_embs should be the ground-truth tokens, while curr_embs are the outputs of some upstream model, conditioned on prev_embs padded by one timestep.
- Parameters:
prev_embs – The embeddings for the T - 1 tokens, with shape (B, T, C)
curr_embs – The embeddings for the T tokens, with shape (B, T, C), which can only attend to the previous embeddings
- Returns:
The inputs to the discriminator transformer, with shape (B, 2T, C), and the attention mask, with shape (2T, 2T)
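Putting the pieces together, a minimal sketch of a module with the same constructor and forward signature might look like the following. It is not the library's implementation, and several details are assumptions: the learned init_emb parameter and its initialization, the contents given to the causal_mask and curr_mask buffers (only their names come from the attribute list above), the True-means-blocked mask convention, and keeping only the first T - 1 ground-truth embeddings (so prev_embs may have shape (B, T, C) or (B, T - 1, C)). NextTokenDiscriminatorSketch and the scoring head in the demo are hypothetical names.

```python
import torch
from torch import Tensor, nn


class NextTokenDiscriminatorSketch(nn.Module):
    """Hypothetical sketch; in the returned mask, True means "may not attend"."""

    def __init__(self, emb_dim: int, max_tsz: int) -> None:
        super().__init__()
        # Assumed: a learned "initial token" placed before the first ground-truth token.
        self.init_emb = nn.Parameter(torch.randn(emb_dim) * 0.02)
        # Assumed buffer contents, precomputed for max_tsz and sliced in forward().
        causal = ~torch.tril(torch.ones(max_tsz, max_tsz, dtype=torch.bool))
        only_self = ~torch.eye(max_tsz, dtype=torch.bool)
        self.register_buffer("causal_mask", causal, persistent=False)
        self.register_buffer("curr_mask", only_self, persistent=False)

    def forward(self, prev_embs: Tensor, curr_embs: Tensor) -> tuple[Tensor, Tensor]:
        bsz, tsz, dim = curr_embs.shape
        # Shift the ground truth right by one slot, starting with the initial token,
        # so slot t holds ground-truth token t - 1.
        init = self.init_emb.view(1, 1, dim).expand(bsz, 1, dim)
        prev = torch.cat((init, prev_embs[:, : tsz - 1]), dim=1)
        embs = torch.cat((prev, curr_embs), dim=1)  # (B, 2T, C)

        causal = self.causal_mask[:tsz, :tsz]
        only_self = self.curr_mask[:tsz, :tsz]
        blocked = torch.ones_like(causal)
        mask = torch.cat(
            (
                # Ground-truth slots: causal among themselves, never see predictions.
                torch.cat((causal, blocked), dim=1),
                # Prediction t: ground-truth slots 0..t (tokens before t) plus itself.
                torch.cat((causal, only_self), dim=1),
            ),
            dim=0,
        )  # (2T, 2T)
        return embs, mask


if __name__ == "__main__":
    disc = NextTokenDiscriminatorSketch(emb_dim=16, max_tsz=32)
    encoder = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=16, nhead=4, batch_first=True),
        num_layers=2,
    )
    head = nn.Linear(16, 1)  # hypothetical scoring head
    prev, curr = torch.randn(2, 8, 16), torch.randn(2, 8, 16)
    embs, mask = disc(prev, curr)
    out = encoder(embs, mask=mask)  # (2, 16, 16)
    scores = head(out[:, 8:]).squeeze(-1)  # (2, 8): one score per predicted token
    print(embs.shape, mask.shape, scores.shape)
```

The demo feeds the outputs into a standard torch.nn.TransformerEncoder and reads one score per predicted token from the last T positions. Precomputing both masks at max_tsz and slicing them in forward avoids rebuilding them on every call, and registering them with persistent=False keeps them out of the state dict.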