"""Defines the API for training consistency models.
This code largely references the `OpenAI implementation <https://github.com/openai/consistency_models>`_,
as well as `Simo Ryu's implementation <https://github.com/cloneofsimo/consistency_models/tree/master>`_.
.. code-block:: python
# Instantiates the consistency model module.
diff = ConsistencyModel(sigmas)
# The forward pass should take a noisy tensor and the current timestep
# and return the denoised tensor. Can add class conditioning as well but
# you need a function which satisfies this signature.
def forward_pass(x: Tensor, t: Tensor) -> Tensor:
...
# Compute ths loss.
loss = diff.loss(forward_pass, x, state)
loss.sum().backward()
# Sample from the model. Consistency models produce good samples even with
# a small number of steps.
samples = diff.sample(forward_pass, x.shape, x.device, num_steps=4)
"""

import math
from typing import Callable, Literal

import torch
from torch import Tensor, nn

from ml.tasks.losses.loss import loss_fn
from ml.utils.ops import append_dims


class ConsistencyModel(nn.Module):
"""Defines a module which implements consistency diffusion models.
This model introduces an auxiliary consistency penalty to the loss function
to encourage the ODE to be smooth, allowing for few-step inference.
This also implements the improvements to vanilla diffusion described in
``Elucidating the Design Space of Diffusion-Based Generative Models``.
Parameters:
total_steps: The maximum number of training steps, used for determining
the discretization step schedule.
sigma_data: The standard deviation of the data.
sigma_max: The maximum standard deviation for the diffusion process.
sigma_min: The minimum standard deviation for the diffusion process.
rho: The rho constant for the noise schedule.
p_mean: A constant which controls the distribution of timesteps to
sample for training. Training biases towards sampling timesteps
from the less noisy end of the spectrum to improve convergence.
p_std: Another constant that controls the distribution of timesteps for
training, used in conjunction with ``p_mean``.
start_scales: The number of different discretization scales to use at
the start of training. At the start of training, a small number of
scales is used to encourage the model to learn more quickly, which
is increased over time.
end_scales: The number of different discretization scales to use at the
end of training.
loss_dim: The dimension over which to compute the loss. This should
typically be the channel dimension.
loss_factor: The factor to use for the pseudo-Huber loss. The default
value comes from the Consistency Models improvements paper.
"""
__constants__ = [
"total_steps",
"sigma_data",
"sigma_max",
"sigma_min",
"rho",
"p_mean",
"p_std",
"start_scales",
"end_scales",
]

    def __init__(
self,
total_steps: int | None = None,
sigma_data: float = 0.5,
sigma_max: float = 80.0,
sigma_min: float = 0.002,
rho: float = 7.0,
p_mean: float = -1.1,
p_std: float = 2.0,
start_scales: int = 20,
end_scales: int = 1280,
loss_type: Literal["mse", "l1", "pseudo-huber"] = "pseudo-huber",
loss_dim: int = -1,
loss_factor: float = 0.00054,
) -> None:
super().__init__()
self.total_steps = total_steps
self.sigma_data = sigma_data
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.rho = rho
self.p_mean = p_mean
self.p_std = p_std
self.start_scales = start_scales
self.end_scales = end_scales
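
        # Builds the loss function once at construction time, based on the
        # configured loss type.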
def get_loss_fn() -> Callable[[Tensor, Tensor], Tensor]:
match loss_type:
case "mse" | "l1":
return loss_fn(loss_type)
case "pseudo-huber":
return loss_fn(loss_type, dim=loss_dim, pseudo_huber_factor=loss_factor)
case _:
raise ValueError(f"Invalid loss type: {loss_type}")
self.loss_fn = get_loss_fn()

    def loss(self, model: Callable[[Tensor, Tensor], Tensor], x: Tensor, step: int) -> Tensor:
"""Computes the consistency model loss.
Args:
model: The model forward process, which takes a tensor with the
same shape as the input data plus a timestep and returns the
predicted noise or target, with shape ``(*)``.
x: The input data, with shape ``(*)``
step: The current training step, used to determine the number of
discretization steps to use.
Returns:
The loss for supervising the model.
"""
dims = x.ndim
# The number of discretization steps runs on a schedule.
num_scales = self._get_num_scales(step)
# Rather than randomly sampling some timesteps for training, we bias the
# samples to be closer to the less noisy end, which improves training
# stability. This distribution is defined as a function of the standard
# deviations.
timesteps = self._sample_timesteps(x, num_scales)
t_current, t_next = timesteps / (num_scales - 1), (timesteps + 1) / (num_scales - 1)
# Converts timesteps to sigmas.
sigma_next, sigma_current = self._get_sigmas(torch.stack((t_next, t_current))).unbind(0)
noise = torch.randn_like(x)
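        # Saves the RNG state so that the online pass here and the no-grad
        # target pass below see identical dropout masks.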
dropout_state = torch.get_rng_state()
x_current = x + noise * append_dims(sigma_current, dims)
y_current = self._call_model(model, x_current, sigma_current)
# Resets the dropout state and runs the target model.
torch.set_rng_state(dropout_state)
with torch.no_grad():
x_next = x + noise * append_dims(sigma_next, dims)
y_next = self._call_model(model, x_next, sigma_next).detach()
loss = self._get_loss(y_current, y_next, sigma_next, sigma_current)
return loss

    @torch.no_grad()
def partial_sample(
self,
model: Callable[[Tensor, Tensor], Tensor],
reference_sample: Tensor,
start_percent: float,
num_steps: int,
) -> Tensor:
"""Samples from the model, starting from a given reference sample.
Partial sampling takes a reference sample, adds some noise to it, then
denoises the sample using the model. This can be used for doing
style transfer, where the reference sample is the source image which
the model redirects to look more like some target style.
Args:
model: The model forward process, which takes a tensor with the
same shape as the input data plus a timestep and returns the
predicted noise or target, with shape ``(*)``.
reference_sample: The reference sample, with shape ``(*)``.
start_percent: The percentage of timesteps to start sampling from.
num_steps: The number of sampling steps to use.
Returns:
The samples, with shape ``(num_steps + 1, *)``, with the first
sample (i.e., ``samples[0]``) as the denoised output and the last
sample (i.e., ``samples[-1]``) as the reference sample.
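
        Example (a minimal sketch; ``forward_pass`` and ``source_image`` are
        hypothetical stand-ins matching the shapes described above):

        .. code-block:: python

            diff = ConsistencyModel(total_steps=100_000)

            # Starts halfway through the noise schedule and runs 8
            # denoising steps; ``styled[0]`` is the final output.
            styled = diff.partial_sample(forward_pass, source_image, 0.5, 8)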
"""
assert 0.0 <= start_percent <= 1.0
device = reference_sample.device
timesteps = torch.linspace(start_percent, 1, num_steps + 1, device=device, dtype=torch.float32)
sigmas = self._get_sigmas(timesteps)
x = reference_sample
x = x + torch.randn_like(x) * sigmas[None, 0]
samples = torch.empty((num_steps + 1, *x.shape), device=x.device)
samples[num_steps] = x
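        # Alternates between fully denoising at the current noise level and
        # re-noising to the next (smaller) noise level, which is the
        # standard multistep consistency sampling procedure.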
for i in range(num_steps):
x = self._call_model(model, x, sigmas[None, i])
samples[num_steps - 1 - i] = x
if i < num_steps - 1:
x = x + torch.randn_like(x) * sigmas[None, i + 1]
return samples

    @torch.no_grad()
def sample(
self,
model: Callable[[Tensor, Tensor], Tensor],
shape: tuple[int, ...],
device: torch.device,
num_steps: int,
) -> Tensor:
"""Samples from the model.
Args:
model: The model forward process, which takes a tensor with the
same shape as the input data plus a timestep and returns the
predicted noise or target, with shape ``(*)``.
shape: The shape of the samples.
device: The device to put the samples on.
num_steps: The number of sampling steps to use.
Returns:
The samples, with shape ``(num_steps + 1, *)``, with the first
sample (i.e., ``samples[0]``) as the denoised output and the last
sample (i.e., ``samples[-1]``) as the random noise.
"""
timesteps = torch.linspace(0, 1, num_steps + 1, device=device, dtype=torch.float32)
sigmas = self._get_sigmas(timesteps)
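        # Sampling starts from pure Gaussian noise at the maximum noise
        # level, then alternates between denoising and re-noising exactly as
        # in ``partial_sample``.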
x = torch.randn(shape, device=device) * sigmas[0]
samples = torch.empty((num_steps + 1, *x.shape), device=x.device)
samples[num_steps] = x
for i in range(num_steps):
x = self._call_model(model, x, sigmas[None, i])
samples[num_steps - 1 - i] = x
if i < num_steps - 1:
x = x + torch.randn_like(x) * sigmas[None, i + 1]
return samples

    def _get_scalings(self, sigma: Tensor) -> tuple[Tensor, Tensor, Tensor]:
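        # EDM-style preconditioning coefficients: ``c_skip`` and ``c_out``
        # mix the raw network output with the noisy input, while ``c_in``
        # normalizes the input variance.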
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out, c_in

    def _call_model(self, model: Callable[[Tensor, Tensor], Tensor], x_t: Tensor, sigmas: Tensor) -> Tensor:
c_skip, c_out, c_in = (append_dims(x, x_t.ndim) for x in self._get_scalings(sigmas))
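        # Conditions the model on ``0.25 * log(sigma)``, scaled by 1000 to
        # match typical timestep-embedding magnitudes; the small epsilon
        # avoids ``log(0)`` where sigma has been clamped to zero.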
timesteps = 1000 * 0.25 * torch.log(sigmas + 1e-44)
model_output = model(c_in * x_t, timesteps)
denoised = c_out * model_output + c_skip * x_t
return denoised

    def _get_loss(self, y_hat: Tensor, y: Tensor, sigma_next: Tensor, sigma_current: Tensor) -> Tensor:
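        # Weights each sample by the inverse width of its discretization
        # interval (``sigma_current > sigma_next``, since the schedule
        # decreases with the timestep).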
weights = 1 / (sigma_current - sigma_next)
loss = self.loss_fn(y_hat, y)
weights = weights.view(-1, *([1] * (loss.dim() - 1)))
return loss * weights

    @torch.no_grad()
def _get_sigmas(self, timesteps: Tensor) -> Tensor:
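        # Karras et al. noise schedule, interpolating between ``sigma_max``
        # (at t = 0) and ``sigma_min`` (at t = 1) in ``1 / rho`` space, with
        # the final timestep clamped to exactly zero.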
min_inv_rho = self.sigma_min ** (1 / self.rho)
max_inv_rho = self.sigma_max ** (1 / self.rho)
sigmas: Tensor = (max_inv_rho + timesteps * (min_inv_rho - max_inv_rho)) ** self.rho
sigmas[timesteps == 1.0] = 0.0
return sigmas

    def _get_noise_distribution(self, sigma_next: Tensor, sigma_current: Tensor) -> Tensor:
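        # Computes the probability mass that a log-normal distribution over
        # sigma (with location ``p_mean`` and scale ``p_std``) assigns to
        # each discretization interval, via the Gaussian CDF (erf).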
denom = math.sqrt(2) * self.p_std
lhs = torch.erf((torch.log(sigma_next) - self.p_mean) / denom)
rhs = torch.erf((torch.log(sigma_current) - self.p_mean) / denom)
return rhs - lhs

    def _sample_timesteps(self, x: Tensor, num_scales: int) -> Tensor:
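        # Draws per-sample timestep indices with probability proportional to
        # the log-normal mass of each interval, biasing training towards the
        # less noisy end of the schedule.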
timesteps = torch.linspace(0, 1, num_scales, device=x.device, dtype=torch.float32)
sigmas = self._get_sigmas(timesteps)
noise_dist = self._get_noise_distribution(sigmas[1:], sigmas[:-1])
timesteps = torch.multinomial(noise_dist, x.shape[0], replacement=True)
return timesteps

    def _get_num_scales(self, step: int) -> int:
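        # Doubles the number of discretization scales on a fixed schedule,
        # growing from ``start_scales`` to ``end_scales`` over the course of
        # training.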
if self.total_steps is None:
return self.end_scales + 1
num_steps = min(self.total_steps, step)
k_prime = math.floor(self.total_steps / (math.log2(self.end_scales / self.start_scales) + 1))
return min(self.start_scales * 2 ** math.floor(num_steps / k_prime), self.end_scales) + 1