ml.models.embeddings

Defines a general-purpose API for transformer embedding layers.

```python
import torch

from ml.models.embeddings import get_positional_embeddings, cast_embedding_kind

embeddings = get_positional_embeddings(
    max_tsz=1024,
    embed_dim=128,
    kind="sinusoidal",
    learnable=False,
)

# Input of shape (batch, time, embed_dim); the last dimension must equal embed_dim.
x = torch.randn(3, 5, 128)

# Time-based positional embeddings: the times tensor supplies the
# time (position) of each element in the input.
times = torch.randint(0, 1024, (3, 5))
y1 = embeddings(x, times=times)

# Offset-based positional embeddings: the input is assumed to be in
# temporal order, and offset is the position of the first element.
y2 = embeddings(x, offset=1)

assert y1.shape == y2.shape == x.shape

# cast_embedding_kind lets you parametrize the embedding kind as a plain string.
# Here `...` stands for the other arguments shown above, and `my_kind` is any
# string, for example one read from a config file.
embeddings = get_positional_embeddings(..., kind=cast_embedding_kind(my_kind))
```

Choices for the embedding kind are:

  • "identity": No positional embeddings are added.

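  • "learned": A table of per-position embeddings, one row per position up to max_tsz.

  • "sinusoidal": Sinusoidal embeddings (optionally learnable).

  • "rotary": Rotary embeddings, which encode position by rotating pairs of feature dimensions (popular for training transformers).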
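The following is a minimal, framework-free sketch of how a string-to-kind cast such as cast_embedding_kind can work. EmbeddingKind and cast_kind_sketch are hypothetical names, and rejecting unknown strings with a ValueError is an assumption about the library's behavior, not something this reference documents.

```python
from typing import Literal, cast, get_args

# Hypothetical stand-in for the library's embedding-kind literal type.
EmbeddingKind = Literal["identity", "learned", "sinusoidal", "rotary"]


def cast_kind_sketch(k: str) -> EmbeddingKind:
    """Narrows a plain string to EmbeddingKind, rejecting unknown values.

    Assumption: unknown strings raise ValueError (not documented for the library).
    """
    choices = get_args(EmbeddingKind)  # ('identity', 'learned', 'sinusoidal', 'rotary')
    if k not in choices:
        raise ValueError(f"Invalid embedding kind {k!r}; expected one of {choices}")
    return cast(EmbeddingKind, k)
```

Because the choices are read from the Literal itself, the type annotation and the runtime check cannot drift apart.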

ml.models.embeddings.cast_embedding_kind(k: str) → Literal['identity', 'learned', 'sinusoidal', 'rotary']

Casts the string k to one of the supported embedding kinds.

class ml.models.embeddings.IdentityPositionalEmbeddings(*args, **kwargs)

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor, offset: int = 0, times: Tensor | None = None) → Tensor

Returns x unchanged. The offset and times arguments are accepted so that this module can be used interchangeably with the other positional embedding modules.

Note

Although the forward pass is defined in this function, you should call the module instance rather than forward directly: calling the instance runs the registered hooks, while calling forward silently ignores them.

class ml.models.embeddings.LearnedPositionalEmbeddings(max_tsz: int, embed_dim: int, weight_init: Literal['orthogonal', 'normal', 'biased_normal', 'uniform', 'kaiming_uniform', 'kaiming_normal', 'xavier_uniform', 'xavier_normal', 'trunc_normal', 'dirac', 'constant', 'zeros', 'ones'] = 'normal', learnable: bool = True)

Bases: Module

Defines a learned embeddings module.

Parameters:
  • max_tsz – The maximum sequence length.

  • embed_dim – The embedding dimension.

  • weight_init – The initialization type for the embedding weight.

  • learnable – Whether the embeddings are learnable.


reset_parameters() → None
forward(x: Tensor, offset: int = 0, times: Tensor | None = None) → Tensor

Applies the learned positional embeddings to x, with offset and times interpreted as in the module overview.
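As a rough illustration of how offset and times select rows from a learned table, here is a pure-Python sketch for a single unbatched sequence. LearnedTableSketch is hypothetical, and it assumes the embeddings are added to the input; the library's module works on batched tensors and its combination rule may differ.

```python
import random
from typing import List, Optional, Sequence

Vector = List[float]


class LearnedTableSketch:
    """Hypothetical stand-in for a learned positional embedding table.

    Holds one row of embed_dim values per position, up to max_tsz positions.
    In a real model the rows would be trainable parameters; here they are
    plain random floats.
    """

    def __init__(self, max_tsz: int, embed_dim: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        # Draw each entry from a standard normal, loosely like weight_init="normal".
        self.table = [
            [rng.gauss(0.0, 1.0) for _ in range(embed_dim)] for _ in range(max_tsz)
        ]

    def __call__(
        self,
        x: Sequence[Vector],
        offset: int = 0,
        times: Optional[Sequence[int]] = None,
    ) -> List[Vector]:
        # With times, each element brings its own position; otherwise the
        # elements are taken to be in order, starting at position offset.
        if times is not None:
            positions = list(times)
        else:
            positions = list(range(offset, offset + len(x)))
        # Assumption: the positional row is added elementwise to the input row.
        return [
            [xi + ei for xi, ei in zip(row, self.table[pos])]
            for row, pos in zip(x, positions)
        ]
```

Calling the sketch with offset=2 or with times=[2, 3, 4] on a three-element sequence selects the same rows, which is the relationship between the two modes described in the module overview.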

class ml.models.embeddings.SinusoidalEmbeddings(embed_dim: int | None = None, max_tsz: int | None = None, learnable: bool = True, base: int = 10000)

Bases: Module

Defines a sinusoidal embeddings module.

Parameters:
  • embed_dim – The embedding dimension.

  • max_tsz – The maximum sequence length.

  • learnable – Whether the embeddings are learnable.

  • base – The base for the sinusoidal embeddings.
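For reference, the standard sinusoidal table from "Attention Is All You Need" can be sketched in plain Python as below, with base playing the role of the constant 10000 in that paper. sinusoidal_table is a hypothetical helper; the library may lay out the sin and cos components differently (for example, as two concatenated halves rather than interleaved pairs).

```python
import math
from typing import List


def sinusoidal_table(tsz: int, embed_dim: int, base: int = 10000) -> List[List[float]]:
    """Builds a (tsz, embed_dim) table of sinusoidal position embeddings.

    Interleaved layout: even dimensions hold sin, odd dimensions hold cos,
    and each (sin, cos) pair i shares the frequency base ** (-2i / embed_dim).
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")
    table = []
    for pos in range(tsz):
        row = []
        for i in range(embed_dim // 2):
            angle = pos / (base ** (2 * i / embed_dim))
            row.extend([math.sin(angle), math.cos(angle)])
        table.append(row)
    return table
```

Because every (sin, cos) pair lies on the unit circle, the table needs no learned parameters; a learnable variant would simply start from these values.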


forward(x: Tensor, offset: int = 0, times: Tensor | None = None) → Tensor

Applies the sinusoidal positional embeddings to x, with offset and times interpreted as in the module overview.

reset_parameters() → None
get_embeddings(tsz: int, embed_dim: int, device: device | None = None, dtype: dtype | None = None) → Tensor

ml.models.embeddings.get_rotary_embeddings(tsz: int, embed_dim: int, device: device, dtype: dtype, offset: int = 0, base: int = 10000) → Tensor
ml.models.embeddings.apply_rotary_embeddings(x: Tensor, embs: Tensor, offset: int = 0, times: Tensor | None = None) → Tensor
ml.models.embeddings.rotary_embeddings(x: Tensor, offset: int = 0, base: int = 10000) → Tensor

Defines a single function for applying rotary embeddings.

This is slower than using the module, but it doesn’t require pre-initializing the embeddings, so it can be used when running online.

Parameters:
  • x – The input tensor.

  • offset – The offset for the first element.

  • base – The base for the sinusoidal embeddings.

Returns:

The input tensor with rotary embeddings applied.
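The rotation that rotary embeddings apply can be sketched without any framework as follows. rotary_sketch is a hypothetical name that mirrors the rotary_embeddings(x, offset, base) arguments on a single (tsz, dim) sequence of lists. It rotates adjacent feature pairs; some implementations rotate the first half of the features against the second half instead, and this reference does not say which layout the library uses.

```python
import math
from typing import List, Sequence


def rotary_sketch(
    x: Sequence[Sequence[float]], offset: int = 0, base: int = 10000
) -> List[List[float]]:
    """Applies rotary position embeddings to a (tsz, dim) sequence.

    Each adjacent feature pair (2i, 2i + 1) at position p is rotated by the
    angle p * base ** (-2i / dim). Positions start at offset.
    """
    out = []
    for t, row in enumerate(x):
        dim = len(row)
        if dim % 2 != 0:
            raise ValueError("feature dimension must be even")
        pos = offset + t
        rotated = []
        for i in range(dim // 2):
            theta = pos * base ** (-2 * i / dim)
            a, b = row[2 * i], row[2 * i + 1]
            c, s = math.cos(theta), math.sin(theta)
            # 2D rotation of the pair (a, b) by theta.
            rotated.extend([a * c - b * s, a * s + b * c])
        out.append(rotated)
    return out
```

A useful property of this construction is that the dot product of two rotated vectors depends only on the difference of their positions, so shifting both positions by the same amount leaves attention scores unchanged; rotations also preserve vector norms.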

class ml.models.embeddings.RotaryEmbeddings(base: int = 10000)

Bases: Module

Defines a rotary embeddings module.

Parameters:

base – The base for the sinusoidal embeddings.

forward(x: Tensor, offset: int = 0, times: Tensor | None = None) → Tensor

Applies rotary embeddings to x, with offset and times interpreted as in the module overview.

ml.models.embeddings.get_positional_embeddings(kind: Literal['identity']) → IdentityPositionalEmbeddings
ml.models.embeddings.get_positional_embeddings(kind: Literal['learned'], *, max_tsz: int, embed_dim: int, weight_init: Literal['orthogonal', 'normal', 'biased_normal', 'uniform', 'kaiming_uniform', 'kaiming_normal', 'xavier_uniform', 'xavier_normal', 'trunc_normal', 'dirac', 'constant', 'zeros', 'ones'] = 'normal', learnable: bool | None = None) → LearnedPositionalEmbeddings
ml.models.embeddings.get_positional_embeddings(kind: Literal['sinusoidal'], *, max_tsz: int | None = None, embed_dim: int | None = None, learnable: bool | None = None, base: int = 10000) → SinusoidalEmbeddings
ml.models.embeddings.get_positional_embeddings(kind: Literal['rotary'], *, base: int = 10000) → RotaryEmbeddings
ml.models.embeddings.get_positional_embeddings(kind: Literal['identity', 'learned', 'sinusoidal', 'rotary'], *, max_tsz: int | None = None, embed_dim: int | None = None, weight_init: Literal['orthogonal', 'normal', 'biased_normal', 'uniform', 'kaiming_uniform', 'kaiming_normal', 'xavier_uniform', 'xavier_normal', 'trunc_normal', 'dirac', 'constant', 'zeros', 'ones'] = 'normal', learnable: bool | None = None, base: int = 10000) → IdentityPositionalEmbeddings | LearnedPositionalEmbeddings | SinusoidalEmbeddings | RotaryEmbeddings

Builds the positional embeddings module for the requested kind. The overloads above show which keyword arguments apply to each kind.

Parameters:
  • kind – The type of embedding to use.

  • max_tsz – The maximum sequence length.

  • embed_dim – The embedding dimension.

  • weight_init – The weight initialization for learned embeddings.

  • learnable – Whether the embeddings are learnable; if not provided, uses sensible defaults.

  • base – The base for the sinusoidal and rotary embeddings.

Returns:

The positional embeddings module.

Raises:

ValueError – If an invalid embedding kind is supplied.
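The kind-based dispatch can be pictured with the sketch below, which uses hypothetical dataclass stand-ins instead of the real modules. The routing mirrors the overloads above, but raising ValueError for a learned kind without max_tsz or embed_dim, and defaulting learnable to True when it is None, are assumptions; the reference only says that sensible defaults are used.

```python
from dataclasses import dataclass
from typing import Optional, Union


# Hypothetical stand-ins for the four embedding modules: each records only
# the constructor arguments that apply to its kind.
@dataclass
class IdentitySpec:
    pass


@dataclass
class LearnedSpec:
    max_tsz: int
    embed_dim: int
    weight_init: str = "normal"
    learnable: bool = True


@dataclass
class SinusoidalSpec:
    max_tsz: Optional[int] = None
    embed_dim: Optional[int] = None
    learnable: bool = True
    base: int = 10000


@dataclass
class RotarySpec:
    base: int = 10000


Spec = Union[IdentitySpec, LearnedSpec, SinusoidalSpec, RotarySpec]


def build_spec(
    kind: str,
    *,
    max_tsz: Optional[int] = None,
    embed_dim: Optional[int] = None,
    weight_init: str = "normal",
    learnable: Optional[bool] = None,
    base: int = 10000,
) -> Spec:
    """Routes the shared keyword arguments to the spec for `kind`."""
    # Assumption: learnable=None falls back to True, the class default.
    learn = True if learnable is None else learnable
    if kind == "identity":
        return IdentitySpec()
    if kind == "learned":
        if max_tsz is None or embed_dim is None:
            raise ValueError("'learned' embeddings require max_tsz and embed_dim")
        return LearnedSpec(max_tsz, embed_dim, weight_init, learn)
    if kind == "sinusoidal":
        return SinusoidalSpec(max_tsz, embed_dim, learn, base)
    if kind == "rotary":
        return RotarySpec(base)
    raise ValueError(f"Invalid embedding kind: {kind!r}")
```

Keeping one keyword-only signature for every kind lets callers pass a single config dictionary and ignore the arguments that do not apply.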

ml.models.embeddings.fourier_embeddings(t: Tensor, dim: int, max_period: int = 10000) → Tensor

class ml.models.embeddings.FourierEmbeddings(dim: int, max_period: int = 10000)

Bases: Module

Defines a module for applying Fourier embeddings to timesteps.

This module differs from the other positional embedding modules because it expects a continuous time input, rather than a discrete time input.

Parameters:
  • dim – The number of embedding dimensions. This value is used to determine how many different frequencies to use, and a higher value means higher frequencies.

  • max_period – The maximum period for the embeddings. This should roughly be in line with the maximum number of timesteps; the default value of 10,000 is commonly used in NLP applications, and is derived from operating on sequence lengths of 100 to 1000 tokens.
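A minimal sketch of Fourier timestep embeddings for a single scalar timestep, following the formulation commonly used for diffusion-model timestep embeddings: dim // 2 geometrically spaced frequencies, with the cosines and sines concatenated. fourier_sketch is a hypothetical name, and the cos-before-sin ordering and zero padding for odd dim are assumptions rather than documented behavior of fourier_embeddings.

```python
import math
from typing import List


def fourier_sketch(t: float, dim: int, max_period: int = 10000) -> List[float]:
    """Embeds a continuous timestep t into `dim` features.

    Uses dim // 2 frequencies spaced geometrically from 1 down to roughly
    1 / max_period, and concatenates cos(t * f) and sin(t * f).
    """
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2 == 1:
        emb.append(0.0)  # zero-pad so the output always has exactly dim features
    return emb
```

In this sketch, dim sets how many frequencies are sampled between 1 and roughly 1 / max_period, and a timestep of 0 maps to all ones followed by all zeros.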


forward(t: Tensor) → Tensor

Computes the Fourier embeddings of the timesteps t.