"""Defines a general-purpose API for transformer embedding layers.
.. highlight:: python
.. code-block:: python
from ml.models.embeddings import get_positional_embeddings, cast_embedding_kind
embeddings = get_positional_embeddings(
max_tsz=1024,
embed_dim=128,
kind="sinusoidal",
learnable=False,
)
x = torch.arange(3, 5, 8)
# Time-based positional embeddings - the time tensor supplies the
# times for each element in the input.
times = torch.randint(0, 1024, (3, 5))
y1 = embeddings(x, times=times)
# Offset-based positional embeddings - the input is assumed to be in
# temporal order, and the offset is the offset of the first element.
y2 = embeddings(x, offset=1)
assert y1.shape == y2.shape == x.shape
# This lets you parametrize the embedding kind as a string.
embeddings = get_positional_embeddings(..., kind=cast_embedding_kind(my_kind))
Choices for the embedding kind are:
- ``"identity"``: No positional embeddings are added.
- ``"learned"``: Positional embeddings are learned.
- ``"sinusoidal"``: Sinusoidal embeddings.
- ``"rotary"``: Rotary embeddings (popular for training transformers).
"""
import math
from typing import Literal, cast, get_args, overload
import torch
from torch import Tensor, nn
from ml.models.init import InitializationType, init_
EmbeddingKind = Literal["identity", "learned", "sinusoidal", "rotary"]
def cast_embedding_kind(k: str) -> EmbeddingKind:
    """Casts a string to a valid ``EmbeddingKind``, raising an error otherwise."""
    args = get_args(EmbeddingKind)
    assert k in args, f"Invalid embedding kind: '{k}'; valid options are {args}"
    return cast(EmbeddingKind, k)
class IdentityPositionalEmbeddings(nn.Module):
    """No-op embedding module which returns its input unchanged."""

    def forward(self, x: Tensor, offset: int = 0, times: Tensor | None = None) -> Tensor:
        return x
class LearnedPositionalEmbeddings(nn.Module):
    """Defines a learned positional embeddings module.

    Parameters:
        max_tsz: The maximum sequence length.
        embed_dim: The embedding dimension.
        weight_init: The initialization type for the embedding weight.
        learnable: Whether the embeddings are learnable.
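
    Example (a minimal usage sketch; the shapes here are illustrative)::

        emb = LearnedPositionalEmbeddings(max_tsz=1024, embed_dim=128)
        x = torch.randn(2, 16, 128)
        y = emb(x)
        assert y.shape == x.shape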
"""
def __init__(
self,
max_tsz: int,
embed_dim: int,
weight_init: InitializationType = "normal",
learnable: bool = True,
) -> None:
super().__init__()
self.max_tsz = max_tsz
self.embed_dim = embed_dim
self.weight_init = weight_init
self.embeddings = nn.Parameter(torch.empty(max_tsz, embed_dim), requires_grad=learnable)
self.reset_parameters()
    def reset_parameters(self) -> None:
init_(self.embeddings.data, None, self.weight_init)
    def forward(self, x: Tensor, offset: int = 0, times: Tensor | None = None) -> Tensor:
return x + (self.embeddings[None, offset : offset + x.size(1)] if times is None else self.embeddings[times])
class SinusoidalEmbeddings(nn.Module):
    """Defines a sinusoidal positional embeddings module.

    Parameters:
        embed_dim: The embedding dimension.
        max_tsz: The maximum sequence length.
        learnable: Whether the embeddings are learnable.
        base: The base for the sinusoidal embeddings.
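
    Example (a minimal usage sketch; when ``learnable=False`` the embeddings
    are computed lazily from the input shape and cached)::

        emb = SinusoidalEmbeddings(learnable=False)
        x = torch.randn(2, 16, 128)
        y = emb(x)
        assert y.shape == x.shape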
"""
def __init__(
self,
embed_dim: int | None = None,
max_tsz: int | None = None,
learnable: bool = True,
base: int = 10_000,
) -> None:
super().__init__()
self.max_tsz = max_tsz
self.embed_dim = embed_dim
self.base = base
self.embeddings: nn.Parameter | None = None
if learnable:
assert max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
assert embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
self.embeddings = nn.Parameter(torch.empty(max_tsz, embed_dim), requires_grad=learnable)
self.reset_parameters()
self.embeddings_cached: Tensor | None = None
    def forward(self, x: Tensor, offset: int = 0, times: Tensor | None = None) -> Tensor:
embeddings: Tensor | None = self.embeddings
_, tsz, xdim = x.shape
if embeddings is None:
max_tsz = max(tsz, 0 if times is None else int(times.max().item()) + 1) + offset
if self.embeddings_cached is None:
self.embeddings_cached = self.get_embeddings(max_tsz, xdim, x.device, x.dtype)
else:
embed_tsz, embed_dim = self.embeddings_cached.shape
embed_device, embed_dtype = self.embeddings_cached.device, self.embeddings_cached.dtype
if embed_tsz < max_tsz or embed_dim != xdim or embed_device != x.device or embed_dtype != x.dtype:
                    self.embeddings_cached = self.get_embeddings(max_tsz, xdim, x.device, x.dtype)
embeddings = self.embeddings_cached
return x + (embeddings[None, offset : offset + tsz] if times is None else embeddings[times])
    def reset_parameters(self) -> None:
        if self.embeddings is None:
            return
        assert self.max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
        assert self.embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
        self.embeddings.data.copy_(self.get_embeddings(self.max_tsz, self.embed_dim))
    def get_embeddings(
self,
tsz: int,
embed_dim: int,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Tensor:
positions = torch.arange(tsz, device=device, dtype=torch.float32)
dim = torch.arange(embed_dim, device=device, dtype=torch.float32)
dim = self.base ** (2 * (dim // 2) / embed_dim)
embeddings = positions[:, None] / dim[None, :]
embeddings[:, 0::2] = torch.sin(embeddings[:, 0::2])
embeddings[:, 1::2] = torch.cos(embeddings[:, 1::2])
return embeddings.to(dtype)
@torch.no_grad()
def get_rotary_embeddings(
tsz: int,
embed_dim: int,
device: torch.device,
dtype: torch.dtype,
offset: int = 0,
base: int = 10_000,
) -> Tensor:
    """Returns the stacked cosine and sine rotary tensors, with shape ``(2, tsz, embed_dim // 2)``."""
assert embed_dim % 4 == 0, f"Embedding dimension must be divisible by 4, got {embed_dim}"
half_d = embed_dim // 2
theta = 1.0 / (base ** (torch.arange(0, half_d, 2, device=device, dtype=torch.float32) / half_d))
seq_idx = torch.arange(offset, tsz + offset, device=device, dtype=torch.float32)
idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
cos, sin = idx_theta2.cos(), idx_theta2.sin()
return torch.stack((cos, sin), dim=0).to(dtype)
def apply_rotary_embeddings(x: Tensor, embs: Tensor, offset: int = 0, times: Tensor | None = None) -> Tensor:
    """Applies rotary embeddings from ``get_rotary_embeddings`` to the first half of the channels of ``x``."""
cos, sin = embs.unbind(0)
_, tsz, embed_dim = x.shape
half_d = embed_dim // 2
quarter_d = embed_dim // 4
x_rope, x_pass = x[..., :half_d], x[..., half_d:]
neg_half_x = torch.cat([-x_rope[..., quarter_d:], x_rope[..., :quarter_d]], dim=-1)
cos_part = cos[None, offset : offset + tsz] if times is None else cos[times]
sin_part = sin[None, offset : offset + tsz] if times is None else sin[times]
x_rope = x_rope * cos_part + neg_half_x * sin_part
return torch.cat((x_rope, x_pass), dim=-1)
def rotary_embeddings(x: Tensor, offset: int = 0, base: int = 10_000) -> Tensor:
    """Defines a single function for applying rotary embeddings.

    This is slower than using the module, but it doesn't require
    pre-initializing the embeddings, so it can be used when running online.

    Args:
        x: The input tensor.
        offset: The offset for the first element.
        base: The base for the sinusoidal embeddings.

    Returns:
        The input tensor with rotary embeddings applied.
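
    Example (a minimal usage sketch; ``embed_dim`` must be divisible by 4)::

        x = torch.randn(2, 16, 64)
        y = rotary_embeddings(x)
        assert y.shape == x.shape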
"""
(_, tsz, embed_dim), device, dtype = x.shape, x.device, x.dtype
embeddings = get_rotary_embeddings(tsz + offset, embed_dim, device, dtype, 0, base)
return apply_rotary_embeddings(x, embeddings, offset)
class RotaryEmbeddings(nn.Module):
def __init__(self, base: int = 10_000) -> None:
"""Defines a rotary embeddings module.
Args:
base: The base for the sinusoidal embeddings.
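
        Example (a minimal usage sketch; the embedding dimension must be
        divisible by 4)::

            emb = RotaryEmbeddings()
            x = torch.randn(2, 16, 64)
            y = emb(x)
            assert y.shape == x.shape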
"""
super().__init__()
self.base = base
self.embeddings: Tensor | None = None
    def forward(self, x: Tensor, offset: int = 0, times: Tensor | None = None) -> Tensor:
embeddings = self.embeddings
_, tsz, embed_dim = x.shape
max_tsz = max(tsz, 0 if times is None else int(times.max().item()) + 1) + offset
if embeddings is None or embeddings.shape[-2] < max_tsz:
embeddings = self.embeddings = get_rotary_embeddings(max_tsz, embed_dim, x.device, x.dtype, 0, self.base)
return apply_rotary_embeddings(x, embeddings, offset, times)
@overload
def get_positional_embeddings(kind: Literal["identity"]) -> IdentityPositionalEmbeddings:
...
@overload
def get_positional_embeddings(
kind: Literal["learned"],
*,
max_tsz: int,
embed_dim: int,
weight_init: InitializationType = "normal",
learnable: bool | None = None,
) -> LearnedPositionalEmbeddings:
...
@overload
def get_positional_embeddings(
kind: Literal["sinusoidal"],
*,
max_tsz: int | None = None,
embed_dim: int | None = None,
learnable: bool | None = None,
base: int = 10_000,
) -> SinusoidalEmbeddings:
...
@overload
def get_positional_embeddings(
kind: Literal["rotary"],
*,
base: int = 10_000,
) -> RotaryEmbeddings:
...
@overload
def get_positional_embeddings(
kind: EmbeddingKind,
*,
max_tsz: int | None = None,
embed_dim: int | None = None,
weight_init: InitializationType = "normal",
learnable: bool | None = None,
base: int = 10_000,
) -> IdentityPositionalEmbeddings | LearnedPositionalEmbeddings | SinusoidalEmbeddings | RotaryEmbeddings:
...
def get_positional_embeddings(
kind: EmbeddingKind,
*,
max_tsz: int | None = None,
embed_dim: int | None = None,
weight_init: InitializationType = "normal",
learnable: bool | None = None,
base: int = 10_000,
) -> nn.Module:
"""Defines the common module for adding positional embeddings.
Args:
kind: The type of embedding to use.
max_tsz: The maximum sequence length.
embed_dim: The embedding dimension.
weight_init: The weight initialization for learned embeddings.
learnable: Whether the embeddings are learnable; if not provided,
uses sensible defaults.
base: The base for the sinusoidal embeddings.
Returns:
The positional embeddings module.
Raises:
ValueError: If an invalid embedding kind is supplied.
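
    Example (a minimal usage sketch; the shapes here are illustrative)::

        emb = get_positional_embeddings("rotary")
        x = torch.randn(2, 16, 64)
        y = emb(x)
        assert y.shape == x.shape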
"""
match kind:
case "identity":
return IdentityPositionalEmbeddings()
case "learned":
assert max_tsz is not None, "Learned embeddings require `max_tsz` to be set"
assert embed_dim is not None, "Learned embeddings require `embed_dim` to be set"
return LearnedPositionalEmbeddings(
max_tsz=max_tsz,
embed_dim=embed_dim,
weight_init=weight_init,
learnable=True if learnable is None else learnable,
)
case "sinusoidal":
return SinusoidalEmbeddings(
max_tsz=max_tsz,
embed_dim=embed_dim,
learnable=False if learnable is None else learnable,
base=base,
)
case "rotary":
return RotaryEmbeddings(base=base)
case _:
raise ValueError(f"Invalid embedding kind: {kind}")
def fourier_embeddings(t: Tensor, dim: int, max_period: int = 10000) -> Tensor:
    """Computes Fourier (sinusoidal) embeddings for a tensor of continuous timesteps."""
half = dim // 2
idxs = torch.arange(start=0, end=half, device=t.device, dtype=torch.float32)
freqs = torch.exp(-math.log(max_period) * idxs / half)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    # Pads the last dimension with a zero if `dim` is odd, to match the expected dimension.
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
class FourierEmbeddings(nn.Module):
"""Defines a module for applying Fourier embeddings to timesteps.
This module differs from the other positional embedding modules because it
expects a continuous time input, rather than a discrete time input.
Parameters:
dim: The number of embedding dimensions. This value is used to determine
how many different frequencies to use, and a higher value means
higher frequencies.
max_period: The maximum period for the embeddings. This should roughly
be in line with the maximum number of timesteps; the default value
of 10,000 is commonly used in NLP applications, and is derived from
operating on sequence lengths of 100 to 1000 tokens.
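
    Example (a minimal usage sketch; the batch size of 16 is illustrative)::

        emb = FourierEmbeddings(dim=128)
        t = torch.rand(16) * 1000.0
        y = emb(t)
        assert y.shape == (16, 128)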
"""
__constants__ = ["dim", "max_period"]
def __init__(self, dim: int, max_period: int = 10000) -> None:
super().__init__()
self.dim = dim
self.max_period = max_period
    def forward(self, t: Tensor) -> Tensor:
return fourier_embeddings(t, self.dim, self.max_period)