ml.tasks.consistency
Defines the API for training consistency models.
This code is largely based on the OpenAI implementation and on Simo Ryu’s implementation.
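from torch import Tensor
from ml.tasks.consistency import ConsistencyModel

# Instantiates the consistency model module. `total_steps` is the planned
# number of training steps, which sets the discretization schedule.
diff = ConsistencyModel(total_steps=total_steps)

# The forward pass should take a noisy tensor and the current timestep
# and return the denoised tensor. You can add class conditioning as well,
# but you need a function which satisfies this signature.
def forward_pass(x: Tensor, t: Tensor) -> Tensor:
    ...

# Compute the loss. `step` is the current training step, which determines
# how many discretization steps are used.
loss = diff.loss(forward_pass, x, step)
loss.sum().backward()

# Sample from the model. Consistency models produce good samples even with
# a small number of steps. samples[0] is the final denoised output.
samples = diff.sample(forward_pass, x.shape, x.device, num_steps=4)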
- class ml.tasks.consistency.ConsistencyModel(total_steps: int | None = None, sigma_data: float = 0.5, sigma_max: float = 80.0, sigma_min: float = 0.002, rho: float = 7.0, p_mean: float = -1.1, p_std: float = 2.0, start_scales: int = 20, end_scales: int = 1280, loss_type: Literal['mse', 'l1', 'pseudo-huber'] = 'pseudo-huber', loss_dim: int = -1, loss_factor: float = 0.00054)
Bases: torch.nn.Module
Defines a module which implements consistency diffusion models.
This model introduces an auxiliary consistency penalty to the loss function to encourage the ODE to be smooth, allowing for few-step inference.
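The consistency penalty is easiest to satisfy when the model is parameterized so that the boundary condition f(x, sigma_min) = x holds by construction. The sketch below shows the skip and output scalings used in the Consistency Models paper, which build on the EDM preconditioning and its sigma_data constant. Whether this module uses exactly these coefficients is an assumption, and `consistency_scalings` and `denoise` are hypothetical helper names operating on plain floats for clarity.

```python
import math
from typing import Callable


def consistency_scalings(
    sigma: float, sigma_data: float = 0.5, sigma_min: float = 0.002
) -> tuple[float, float]:
    """Returns (c_skip, c_out); c_skip == 1 and c_out == 0 exactly at sigma_min."""
    c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
    c_out = (sigma - sigma_min) * sigma_data / math.sqrt(sigma**2 + sigma_data**2)
    return c_skip, c_out


def denoise(raw_net: Callable[[float, float], float], x: float, sigma: float) -> float:
    """Consistency function f(x, sigma) = c_skip * x + c_out * F(x, sigma)."""
    c_skip, c_out = consistency_scalings(sigma)
    return c_skip * x + c_out * raw_net(x, sigma)
```

At sigma_min the raw network output is ignored, so f returns its input unchanged; at large sigma, c_skip is close to 0 and the network output dominates.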
This also implements the improvements to vanilla diffusion described in Elucidating the Design Space of Diffusion-Based Generative Models.
- Parameters:
total_steps – The maximum number of training steps, used for determining the discretization step schedule.
sigma_data – The standard deviation of the data.
sigma_max – The maximum standard deviation for the diffusion process.
sigma_min – The minimum standard deviation for the diffusion process.
rho – The exponent ρ of the Karras et al. noise schedule. Larger values place the discretization steps more densely near sigma_min.
p_mean – A constant which controls the distribution of timesteps to sample for training. Training biases towards sampling timesteps from the less noisy end of the spectrum to improve convergence.
p_std – Another constant that controls the distribution of timesteps for training, used in conjunction with p_mean.
start_scales – The number of different discretization scales to use at the start of training. Training starts with a small number of scales so that the model learns quickly, and the number is increased over time.
end_scales – The number of different discretization scales to use at the end of training.
loss_type – The loss function to use: 'mse', 'l1', or 'pseudo-huber'.
loss_dim – The dimension over which to compute the loss. This should typically be the channel dimension.
loss_factor – The factor used to set the constant in the pseudo-Huber loss (relevant when loss_type is 'pseudo-huber'). The default value comes from Improved Techniques for Training Consistency Models.
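The schedule parameters interact roughly as follows. This is a minimal sketch assuming the noise levels follow the Karras et al. (EDM) rho-schedule and that the number of scales follows the exponential doubling curriculum from Improved Techniques for Training Consistency Models. The helper names are hypothetical, and the module's own rounding and its handling of total_steps=None may differ.

```python
import math


def karras_sigmas(
    num_scales: int, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0
) -> list[float]:
    """Returns num_scales noise levels in ascending order, from sigma_min to sigma_max."""
    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    # Interpolate linearly in sigma^(1/rho) space, then map back; a larger rho
    # packs more levels near sigma_min.
    return [(lo + i / (num_scales - 1) * (hi - lo)) ** rho for i in range(num_scales)]


def num_scales_at(
    step: int, total_steps: int, start_scales: int = 20, end_scales: int = 1280
) -> int:
    """Number of noise levels used at a given training step (doubling curriculum)."""
    num_doublings = math.log2(end_scales / start_scales)
    # Training is split into (num_doublings + 1) equal-length stages.
    stage_len = max(1, math.floor(total_steps / (num_doublings + 1)))
    # The +1 follows the iCT formula, so the result counts noise levels.
    return min(start_scales * 2 ** (step // stage_len), end_scales) + 1
```

With the defaults and total_steps=700_000, the stages are 100,000 steps long: num_scales_at(0, 700_000) returns 21, num_scales_at(150_000, 700_000) returns 41, and from step 600,000 onward the count is capped at 1281.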
- loss(model: Callable[[Tensor, Tensor], Tensor], x: Tensor, step: int) → Tensor
Computes the consistency model loss.
- Parameters:
model – The model forward process, which takes a tensor with the same shape as the input data plus a timestep and returns the predicted noise or target, with shape (*).
x – The input data, with shape (*).
step – The current training step, used to determine the number of discretization steps to use.
- Returns:
The loss for supervising the model.
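As a rough illustration of what one loss evaluation involves, here is a NumPy sketch of consistency training in the style of Improved Techniques for Training Consistency Models. A pair of adjacent noise levels is drawn with probabilities from a lognormal prior (p_mean, p_std), the same noise is added at both levels, and the prediction at the noisier level is pulled toward the prediction at the less noisy level with a pseudo-Huber distance weighted by 1 / (sigma_{n+1} - sigma_n). Everything here is an assumption for illustration: the helper names are hypothetical, `model(x, sigma)` is assumed to take the noise level directly as its timestep, a single pair is drawn for the whole batch, and there is no autograd, so the stop-gradient on the target branch is only noted in a comment.

```python
import math
from typing import Callable

import numpy as np


def pair_probs(sigmas: list[float], p_mean: float = -1.1, p_std: float = 2.0) -> np.ndarray:
    """P(n) of using the pair (sigmas[n], sigmas[n + 1]) when log(sigma) ~ N(p_mean, p_std^2)."""

    def cdf(s: float) -> float:
        return 0.5 * (1.0 + math.erf((math.log(s) - p_mean) / (math.sqrt(2.0) * p_std)))

    weights = np.array([cdf(sigmas[i + 1]) - cdf(sigmas[i]) for i in range(len(sigmas) - 1)])
    return weights / weights.sum()


def pseudo_huber(
    pred: np.ndarray, target: np.ndarray, loss_factor: float = 0.00054, axis: int = -1
) -> np.ndarray:
    """sqrt(||pred - target||^2 + c^2) - c over `axis`, with c = loss_factor * sqrt(d)."""
    c = loss_factor * math.sqrt(pred.shape[axis])
    sq_dist = np.sum((pred - target) ** 2, axis=axis)
    return np.sqrt(sq_dist + c**2) - c


def consistency_loss(
    model: Callable[[np.ndarray, float], np.ndarray],
    x: np.ndarray,
    sigmas: list[float],
    rng: np.random.Generator,
    axis: int = -1,
) -> np.ndarray:
    # `sigmas` is an ascending noise schedule for the current training step.
    n = int(rng.choice(len(sigmas) - 1, p=pair_probs(sigmas)))
    sigma_n, sigma_next = sigmas[n], sigmas[n + 1]
    z = rng.standard_normal(x.shape)  # the same noise is used at both levels
    # Target: the less noisy point. iCT evaluates it with a stop-gradient copy of the model.
    target = model(x + sigma_n * z, sigma_n)
    pred = model(x + sigma_next * z, sigma_next)
    return pseudo_huber(pred, target, axis=axis) / (sigma_next - sigma_n)
```

A model that ignores its input entirely also scores zero here, which is why the boundary condition at sigma_min (f(x, sigma_min) = x) matters: it rules out such trivial solutions.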
- partial_sample(model: Callable[[Tensor, Tensor], Tensor], reference_sample: Tensor, start_percent: float, num_steps: int) → Tensor
Samples from the model, starting from a given reference sample.
Partial sampling takes a reference sample, adds some noise to it, then denoises the sample using the model. This can be used for style transfer: the reference sample is the source image, and the model pulls it toward the target style it has learned.
- Parameters:
model – The model forward process, which takes a tensor with the same shape as the input data plus a timestep and returns the predicted noise or target, with shape (*).
reference_sample – The reference sample, with shape (*).
start_percent – The percentage of timesteps to start sampling from.
num_steps – The number of sampling steps to use.
- Returns:
The samples, with shape (num_steps + 1, *), with the first sample (i.e., samples[0]) as the denoised output and the last sample (i.e., samples[-1]) as the reference sample.
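A minimal NumPy sketch of partial sampling, assuming `sigmas` is an ascending noise schedule, `start_percent` is a fraction in [0, 1] of the way up that schedule, and `model(x, sigma)` returns a denoised estimate. The reference is noised to the starting level and then denoised with the multistep consistency procedure (denoise, re-noise to a lower level, denoise again). The linear spacing of intermediate levels and the function name are hypothetical choices; only the output ordering is taken from the documented return value.

```python
from typing import Callable

import numpy as np


def partial_sample_sketch(
    model: Callable[[np.ndarray, float], np.ndarray],
    reference: np.ndarray,
    sigmas: list[float],
    start_percent: float,
    num_steps: int,
    rng: np.random.Generator,
) -> np.ndarray:
    sigma_min = sigmas[0]
    sigma_start = sigmas[round(start_percent * (len(sigmas) - 1))]
    # num_steps levels, evenly spaced in sigma from sigma_start down toward sigma_min.
    levels = np.linspace(sigma_start, sigma_min, num_steps + 1)[:-1]

    outputs = [reference]
    x = reference + sigma_start * rng.standard_normal(reference.shape)
    for i, level in enumerate(levels):
        if i > 0:
            # Re-noise the latest clean estimate to the next, lower noise level.
            scale = np.sqrt(max(level**2 - sigma_min**2, 0.0))
            x = outputs[-1] + scale * rng.standard_normal(reference.shape)
        outputs.append(model(x, float(level)))
    # Reverse so samples[0] is the final denoised output and samples[-1] the reference.
    return np.stack(outputs[::-1])
```

Under this convention, a smaller start_percent adds less noise, so the output stays closer to the reference; a larger one gives the model more freedom to change it.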
- sample(model: Callable[[Tensor, Tensor], Tensor], shape: tuple[int, ...], device: device, num_steps: int) → Tensor
Samples from the model.
- Parameters:
model – The model forward process, which takes a tensor with the same shape as the input data plus a timestep and returns the predicted noise or target, with shape (*).
shape – The shape of the samples.
device – The device to put the samples on.
num_steps – The number of sampling steps to use.
- Returns:
The samples, with shape (num_steps + 1, *), with the first sample (i.e., samples[0]) as the denoised output and the last sample (i.e., samples[-1]) as the random noise.
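Sampling from scratch can be pictured as the multistep consistency sampling procedure from the Consistency Models paper: draw pure noise at sigma_max, map it to a clean estimate in one model call, and optionally refine it by re-noising to progressively lower levels and denoising again. Below is a NumPy sketch assuming `sigmas` is an ascending noise schedule and `model(x, sigma)` returns a denoised estimate. The intermediate levels are picked at evenly spaced indices of `sigmas`, which is a hypothetical choice, and the function name is illustrative.

```python
from typing import Callable

import numpy as np


def sample_sketch(
    model: Callable[[np.ndarray, float], np.ndarray],
    shape: tuple[int, ...],
    sigmas: list[float],
    num_steps: int,
    rng: np.random.Generator,
) -> np.ndarray:
    sigma_min = sigmas[0]
    # Indices into the ascending schedule, starting at the top (sigma_max) and moving down.
    idx = np.linspace(len(sigmas) - 1, 0, num_steps, endpoint=False).round().astype(int)
    levels = [sigmas[i] for i in idx]

    noise = levels[0] * rng.standard_normal(shape)
    outputs = [noise]
    x = noise
    for i, level in enumerate(levels):
        if i > 0:
            # Re-noise the latest clean estimate to the next, lower level.
            x = outputs[-1] + np.sqrt(level**2 - sigma_min**2) * rng.standard_normal(shape)
        outputs.append(model(x, level))
    # samples[0] is the final denoised output; samples[-1] is the initial noise.
    return np.stack(outputs[::-1])
```

Even with num_steps=1 this returns a usable sample from a single model call, which is what makes few-step inference possible; extra steps trade compute for quality.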