ml.models.architectures.unet
Defines a general-purpose UNet model.
- class ml.models.architectures.unet.PositionalEmbedding(dim: int, max_length: int = 10000)[source]
Bases: Module
- embedding: Tensor
- forward(x: Tensor) → Tensor [source]
Defines the computation performed at every call.
Note
Call the module instance rather than forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.
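Given the dim and max_length arguments (with the familiar 10000 default) and the embedding buffer, this is most likely a standard sinusoidal embedding table. Below is a minimal sketch under that assumption; whether forward looks rows up by integer timestep, as here, or instead adds the table to a sequence input is itself an assumption:

```python
import math

import torch
from torch import Tensor, nn


class SinusoidalEmbeddingSketch(nn.Module):
    """Hypothetical re-implementation of a sinusoidal positional table."""

    def __init__(self, dim: int, max_length: int = 10000) -> None:
        super().__init__()
        position = torch.arange(max_length).unsqueeze(1)
        # Geometric frequency schedule; assumes `dim` is even.
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        embedding = torch.zeros(max_length, dim)
        embedding[:, 0::2] = torch.sin(position * div_term)
        embedding[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("embedding", embedding)

    def forward(self, x: Tensor) -> Tensor:
        # Assumption: `x` holds integer timesteps; return their embedding rows.
        return self.embedding[x]
```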
- class ml.models.architectures.unet.FFN(in_dim: int, embed_dim: int)[source]
Bases: Module
- forward(x: Tensor, t: Tensor) → Tensor [source]
Defines the computation performed at every call.
Note
Call the module instance rather than forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.
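The forward(x, t) signature suggests a small MLP that folds a time (or other conditioning) embedding into the features. A sketch of one plausible layout, assuming t already has embed_dim features and that the conditioning is additive:

```python
import torch
from torch import Tensor, nn


class FFNSketch(nn.Module):
    """Hypothetical stand-in: project the input, add the time embedding, project out."""

    def __init__(self, in_dim: int, embed_dim: int) -> None:
        super().__init__()
        self.in_proj = nn.Linear(in_dim, embed_dim)
        self.time_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        # Additive conditioning: mix the projected time embedding into the features.
        h = torch.relu(self.in_proj(x) + self.time_proj(t))
        return self.out_proj(h)
```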
- class ml.models.architectures.unet.BasicBlock(in_c: int, out_c: int, embed_c: int | None = None, act: Literal['no_act', 'relu', 'relu6', 'relu2', 'clamp6', 'leaky_relu', 'elu', 'celu', 'selu', 'gelu', 'gelu_fast', 'sigmoid', 'log_sigmoid', 'hard_sigomid', 'tanh', 'softsign', 'softplus', 'silu', 'mish', 'swish', 'hard_swish', 'soft_shrink', 'hard_shrink', 'tanh_shrink', 'soft_sign', 'relu_squared', 'laplace'] = 'relu', norm: Literal['no_norm', 'batch', 'batch_affine', 'instance', 'instance_affine', 'group', 'group_affine', 'layer', 'layer_affine'] = 'batch_affine')[source]
Bases: Module
- forward(x: Tensor, embedding: Tensor | None = None) → Tensor [source]
Defines the computation performed at every call.
Note
Call the module instance rather than forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.
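From its signature, BasicBlock is a conv–norm–activation block whose act and norm arguments select the nonlinearity and normalization, with an optional embed_c-dimensional conditioning embedding. A sketch with the defaults hard-coded (the 3x3 convolution and the additive embedding injection are assumptions):

```python
import torch
from torch import Tensor, nn


class BasicBlockSketch(nn.Module):
    """Hypothetical conv block: conv -> norm -> act, optionally conditioned."""

    def __init__(self, in_c: int, out_c: int, embed_c: int | None = None) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(out_c)  # stands in for the 'batch_affine' default
        self.act = nn.ReLU()               # stands in for the 'relu' default
        self.embed_proj = None if embed_c is None else nn.Linear(embed_c, out_c)

    def forward(self, x: Tensor, embedding: Tensor | None = None) -> Tensor:
        h = self.norm(self.conv(x))
        if self.embed_proj is not None and embedding is not None:
            # Broadcast the per-sample embedding over the spatial dimensions.
            h = h + self.embed_proj(embedding)[:, :, None, None]
        return self.act(h)
```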
- class ml.models.architectures.unet.SelfAttention2d(dim: int, num_heads: int = 8, dropout_prob: float = 0.1)[source]
Bases: Module
- forward(x: Tensor) → Tensor [source]
Defines the computation performed at every call.
Note
Call the module instance rather than forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.
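SelfAttention2d presumably applies multi-head self-attention over the spatial positions of a feature map. A sketch of the usual flatten–attend–reshape pattern, built on nn.MultiheadAttention (the real module may fuse its own projections):

```python
import torch
from torch import Tensor, nn


class SelfAttention2dSketch(nn.Module):
    """Hypothetical 2D self-attention: one token per spatial location."""

    def __init__(self, dim: int, num_heads: int = 8, dropout_prob: float = 0.1) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout_prob, batch_first=True)

    def forward(self, x: Tensor) -> Tensor:
        bsz, chans, height, width = x.shape
        seq = x.flatten(2).transpose(1, 2)  # (B, C, H, W) -> (B, H*W, C)
        out, _ = self.attn(seq, seq, seq, need_weights=False)
        return out.transpose(1, 2).reshape(bsz, chans, height, width)
```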
- class ml.models.architectures.unet.UNet(in_dim: int, embed_dim: int, dim_scales: Sequence[int], input_embedding_dim: int | None = None)[source]
Bases: Module
Defines a general-purpose UNet model.
- Parameters:
in_dim – Number of input dimensions.
embed_dim – Embedding dimension.
dim_scales – List of dimension scales.
input_embedding_dim – The input embedding dimension, if an input embedding is used (for example, when conditioning on time, or some class embedding).
- Inputs:
x: Input tensor of shape (batch_size, in_dim, height, width).
t: Time tensor of shape (batch_size) if use_time is True, and None otherwise.
c: Class tensor of shape (batch_size, class_dim) if use_class is True, and None otherwise.
- Outputs:
x: Output tensor of shape (batch_size, in_dim, height, width).
- forward(x: Tensor, embedding: Tensor | None = None) → Tensor [source]
Defines the computation performed at every call.
Note
Call the module instance rather than forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.
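A minimal usage sketch tying the pieces together. The dim_scales values and the embedding shape (batch_size, input_embedding_dim) are assumptions; the input and output shapes follow the documented contract:

```python
import torch

from ml.models.architectures.unet import UNet

# Hypothetical configuration; the exact channel math behind dim_scales is assumed.
model = UNet(in_dim=3, embed_dim=64, dim_scales=(1, 2, 4, 8), input_embedding_dim=64)

x = torch.randn(2, 3, 64, 64)  # (batch_size, in_dim, height, width)
emb = torch.randn(2, 64)       # assumed (batch_size, input_embedding_dim)
out = model(x, emb)
assert out.shape == x.shape    # documented output shape matches the input
```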