"""Defines a general-purpose UNet model."""
import math
from collections.abc import Sequence
from typing import cast
import torch
from torch import Tensor, nn
from ml.models.activations import ActivationType, get_activation
from ml.models.norms import NormType, get_norm_2d


class PositionalEmbedding(nn.Module):
    """Defines a sinusoidal positional embedding, looked up by timestep index."""

    embedding: Tensor

    def __init__(self, dim: int, max_length: int = 10000) -> None:
        super().__init__()
        self.register_buffer("embedding", self.make_embedding(dim, max_length), persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        # Treats `x` as a tensor of integer positions and looks up their embeddings.
        return self.embedding[x]

    @staticmethod
    def make_embedding(dim: int, max_length: int = 10000) -> Tensor:
        # Builds a sinusoidal embedding table with geometrically spaced
        # frequencies along the embedding dimension.
        embedding = torch.zeros(max_length, dim)
        position = torch.arange(0, max_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(max_length / 2 / math.pi) / dim))
        embedding[:, 0::2] = torch.sin(position * div_term)
        embedding[:, 1::2] = torch.cos(position * div_term)
        return embedding
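

# Usage sketch (not part of the original module): demonstrates the lookup
# semantics of PositionalEmbedding. The `_demo_*` names and all shapes below
# are illustrative assumptions, exercised by the __main__ block at the bottom.
def _demo_positional_embedding() -> None:
    emb = PositionalEmbedding(dim=32)
    t = torch.randint(0, 10000, (4,))  # A batch of four integer timesteps.
    out = emb(t)  # Each timestep indexes one row of the embedding table.
    assert out.shape == (4, 32)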


class FFN(nn.Module):
    """Defines a simple time-conditioned feedforward network."""

    def __init__(self, in_dim: int, embed_dim: int) -> None:
        super().__init__()
        self.init_embed = nn.Linear(in_dim, embed_dim)
        self.time_embed = PositionalEmbedding(embed_dim)
        self.model = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, in_dim),
        )

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        # Projects the input into the embedding space, adds the timestep
        # embedding, and runs the sum through the MLP trunk.
        x = self.init_embed(x)
        t = self.time_embed(t)
        return self.model(x + t)
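

# Usage sketch (not part of the original module): the input and timestep batch
# sizes must match, and the output returns to `in_dim`. Sizes are assumptions.
def _demo_ffn() -> None:
    ffn = FFN(in_dim=2, embed_dim=64)
    x = torch.randn(8, 2)
    t = torch.randint(0, 10000, (8,))  # One timestep per batch element.
    out = ffn(x, t)
    assert out.shape == (8, 2)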


class BasicBlock(nn.Module):
    """Defines a residual convolution block with optional embedding conditioning."""

    def __init__(
        self,
        in_c: int,
        out_c: int,
        embed_c: int | None = None,
        act: ActivationType = "relu",
        norm: NormType = "batch_affine",
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = get_norm_2d(norm, dim=out_c)
        self.act1 = get_activation(act)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = get_norm_2d(norm, dim=out_c)
        self.act2 = get_activation(act)

        # Projects the input embedding to the block's channel dimension.
        self.mlp_emb = (
            nn.Sequential(
                nn.Linear(embed_c, embed_c),
                get_activation(act),
                nn.Linear(embed_c, out_c),
            )
            if embed_c is not None
            else None
        )

        # Shortcut for the residual connection; projects channels if needed.
        self.shortcut = (
            nn.Identity()
            if in_c == out_c
            else nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_c),
            )
        )

    def forward(self, x: Tensor, embedding: Tensor | None = None) -> Tensor:
        out = self.conv1(x)
        out = self.bn1(out)
        if embedding is not None:
            assert self.mlp_emb is not None, "Embedding was provided but embedding projection is None"
            # Adds the projected embedding as a per-channel bias.
            tx = self.mlp_emb(embedding)
            out = out + tx[..., None, None]
        else:
            assert self.mlp_emb is None, "Embedding projection is not None but no embedding was provided"
        out = self.act1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.act2(out + self.shortcut(x))
        return out
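

# Usage sketch (not part of the original module): when `embed_c` is set, the
# forward pass must receive a matching embedding. All sizes are assumptions.
def _demo_basic_block() -> None:
    block = BasicBlock(in_c=16, out_c=32, embed_c=8)
    x = torch.randn(2, 16, 24, 24)
    emb = torch.randn(2, 8)  # Projected to 32 channels and added as a bias.
    out = block(x, emb)
    assert out.shape == (2, 32, 24, 24)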


class SelfAttention2d(nn.Module):
    """Defines multi-head self-attention over the spatial positions of a 2D feature map."""

    def __init__(self, dim: int, num_heads: int = 8, dropout_prob: float = 0.1) -> None:
        super().__init__()
        assert dim % num_heads == 0, f"dim ({dim}) must be divisible by num_heads ({num_heads})"
        self.dim = dim
        self.num_heads = num_heads
        self.q_conv = nn.Conv2d(dim, dim, 1, bias=True)
        self.k_conv = nn.Conv2d(dim, dim, 1, bias=True)
        self.v_conv = nn.Conv2d(dim, dim, 1, bias=True)
        self.o_conv = nn.Conv2d(dim, dim, 1, bias=True)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: Tensor) -> Tensor:
        _, _, h, w = x.shape
        q = self.q_conv.forward(x)
        k = self.k_conv.forward(x)
        v = self.v_conv.forward(x)
        # Splits channels into heads and flattens the spatial dimensions,
        # giving tensors of shape (batch, heads, dim // heads, h * w).
        q = q.unflatten(1, (self.num_heads, self.dim // self.num_heads)).flatten(-2)
        k = k.unflatten(1, (self.num_heads, self.dim // self.num_heads)).flatten(-2)
        v = v.unflatten(1, (self.num_heads, self.dim // self.num_heads)).flatten(-2)
        # Scaled dot-product attention over the spatial positions.
        a = torch.einsum("... c q, ... c k -> ... q k", q, k) / self.dim**0.5
        a = self.dropout(torch.softmax(a, dim=-1))
        o = torch.einsum("... s t, ... c t -> ... c s", a, v)
        # Merges the heads back together, restores the spatial dimensions,
        # and applies the output projection (previously defined but unused).
        o = o.flatten(1, 2).unflatten(-1, (h, w))
        return self.o_conv(o)
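

# Usage sketch (not part of the original module): attention mixes information
# across the h * w spatial positions, so the output shape matches the input.
def _demo_self_attention() -> None:
    attn = SelfAttention2d(dim=64, num_heads=8)
    x = torch.randn(2, 64, 8, 8)
    out = attn(x)
    assert out.shape == x.shape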


class UNet(nn.Module):
    """Defines a general-purpose UNet model.

    Parameters:
        in_dim: Number of input dimensions.
        embed_dim: Embedding dimension.
        dim_scales: List of dimension scales.
        input_embedding_dim: The input embedding dimension, if an input
            embedding is used (for example, when conditioning on time, or
            some class embedding).

    Inputs:
        x: Input tensor of shape ``(batch_size, in_dim, height, width)``.
        embedding: Conditioning tensor of shape
            ``(batch_size, input_embedding_dim)`` if ``input_embedding_dim``
            is set, and ``None`` otherwise.

    Outputs:
        x: Output tensor of shape ``(batch_size, in_dim, height, width)``.
    """

    def __init__(
        self,
        in_dim: int,
        embed_dim: int,
        dim_scales: Sequence[int],
        input_embedding_dim: int | None = None,
    ) -> None:
        super().__init__()

        self.init_embed = nn.Conv2d(in_dim, embed_dim, 1)
        self.down_blocks = cast(list[BasicBlock | nn.Conv2d], nn.ModuleList())
        self.up_blocks = cast(list[BasicBlock | nn.ConvTranspose2d], nn.ModuleList())
        all_dims = (embed_dim, *[embed_dim * s for s in dim_scales])

        # Downsampling path: two residual blocks per scale, followed by a
        # strided convolution (or a 1x1 convolution at the deepest scale).
        for idx, (in_c, out_c) in enumerate(zip(all_dims[:-1], all_dims[1:])):
            is_last = idx == len(all_dims) - 2
            self.down_blocks.extend(
                nn.ModuleList(
                    [
                        BasicBlock(in_c, in_c, input_embedding_dim),
                        BasicBlock(in_c, in_c, input_embedding_dim),
                        nn.Conv2d(in_c, out_c, 3, 2, 1) if not is_last else nn.Conv2d(in_c, out_c, 1),
                    ]
                )
            )

        # Upsampling path: mirrors the downsampling path, with skip
        # connections concatenated before each residual block.
        for idx, (in_c, out_c, skip_c) in enumerate(zip(all_dims[::-1][:-1], all_dims[::-1][1:], all_dims[:-1][::-1])):
            is_last = idx == len(all_dims) - 2
            self.up_blocks.extend(
                nn.ModuleList(
                    [
                        BasicBlock(in_c + skip_c, in_c, input_embedding_dim),
                        BasicBlock(in_c + skip_c, in_c, input_embedding_dim),
                        nn.ConvTranspose2d(in_c, out_c, (2, 2), 2) if not is_last else nn.Conv2d(in_c, out_c, 1),
                    ]
                )
            )

        # Bottleneck: residual blocks around a spatial self-attention layer.
        self.mid_blocks = cast(
            list[BasicBlock | SelfAttention2d],
            nn.ModuleList(
                [
                    BasicBlock(all_dims[-1], all_dims[-1], input_embedding_dim),
                    SelfAttention2d(all_dims[-1]),
                    BasicBlock(all_dims[-1], all_dims[-1], input_embedding_dim),
                ]
            ),
        )

        # Output head: a final residual block and a 1x1 projection back to `in_dim`.
        self.out_blocks = cast(
            list[BasicBlock | nn.Conv2d],
            nn.ModuleList(
                [
                    BasicBlock(embed_dim, embed_dim, input_embedding_dim),
                    nn.Conv2d(embed_dim, in_dim, 1, bias=True),
                ]
            ),
        )

    def forward(self, x: Tensor, embedding: Tensor | None = None) -> Tensor:
        x = self.init_embed(x)
        skip_conns = []
        residual = x.clone()

        # Downsampling path; stores a skip connection after each residual block.
        for block in self.down_blocks:
            if isinstance(block, BasicBlock):
                x = block.forward(x, embedding)
                skip_conns.append(x)
            else:
                x = block.forward(x)

        # Bottleneck blocks.
        for block in self.mid_blocks:
            if isinstance(block, BasicBlock):
                x = block.forward(x, embedding)
            else:
                x = block.forward(x)

        # Upsampling path; consumes the skip connections in reverse order.
        for block in self.up_blocks:
            if isinstance(block, BasicBlock):
                x = torch.cat((x, skip_conns.pop()), dim=1)
                x = block.forward(x, embedding)
            else:
                x = block.forward(x)

        # Global residual connection from the initial embedding.
        x = x + residual
        for block in self.out_blocks:
            if isinstance(block, BasicBlock):
                x = block.forward(x, embedding)
            else:
                x = block.forward(x)
        return x
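

# End-to-end usage sketch (not part of the original module): a forward pass
# through the UNet with a conditioning embedding. The height and width must
# survive the stride-2 downsamples, so they should be divisible by
# 2 ** (len(dim_scales) - 1). All sizes below are illustrative assumptions.
if __name__ == "__main__":
    model = UNet(in_dim=3, embed_dim=32, dim_scales=(1, 2, 4), input_embedding_dim=16)
    x = torch.randn(2, 3, 32, 32)
    emb = torch.randn(2, 16)
    out = model(x, emb)
    assert out.shape == x.shape  # The UNet preserves the input shape.

    _demo_positional_embedding()
    _demo_ffn()
    _demo_basic_block()
    _demo_self_attention()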