ml.tasks.diffusion

Defines the API for Gaussian diffusion.

This is largely taken from here.

This module can be used to train a Gaussian diffusion model as follows.

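import torch

from ml.tasks.diffusion import GaussianDiffusion

# Instantiate the beta schedule and diffusion module.
diff = GaussianDiffusion()

# Pseudo-training loop; `ds`, `index`, `model` and `optimizer` are placeholders.
for _ in range(1000):
    images = ds[index]  # Get some images from the dataset
    optimizer.zero_grad()
    # `loss` takes the model first and returns an unreduced tensor, so reduce it.
    loss = diff.loss(model, images).mean()
    loss.backward()
    optimizer.step()

# Sample from the model; `sample` takes the output shape and device.
with torch.no_grad():
    generated = diff.sample(model, images.shape, images.device)
show_image(generated[-1])  # `show_image` is a placeholder display helper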

Choices for the beta schedule are:

  • "linear": Linearly increasing beta.

  • "quad": Quadratically increasing beta.

  • "warmup": Linearly increasing beta with a warmup period.

  • "const": Constant beta.

  • "cosine": Cosine annealing schedule.

  • "jsd": Jensen-Shannon divergence schedule.
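The schedules above can be sketched in plain Python as follows. This is an illustration under assumptions, not the module's code: the formulas follow the conventions of the original DDPM reference implementation ("quad" is linear in √β and then squared, "warmup" ramps linearly over the first `warmup` fraction of steps and then holds at `beta_end`, "jsd" is 1/T, 1/(T − 1), …, 1), and `linspace` and `beta_schedule` are hypothetical helper names. The cosine schedule is defined differently and is omitted here.

```python
import math


def linspace(start: float, end: float, n: int) -> list[float]:
    """Evenly spaced values from start to end inclusive (like torch.linspace)."""
    if n == 1:
        return [start]
    step = (end - start) / (n - 1)
    return [start + i * step for i in range(n)]


def beta_schedule(
    schedule: str,
    num_timesteps: int,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    warmup: float = 0.1,
) -> list[float]:
    """Hypothetical sketch of the non-cosine schedules (DDPM reference conventions)."""
    if schedule == "linear":
        return linspace(beta_start, beta_end, num_timesteps)
    if schedule == "quad":
        # Linear in sqrt(beta), then squared.
        return [b**2 for b in linspace(beta_start**0.5, beta_end**0.5, num_timesteps)]
    if schedule == "warmup":
        # Linear ramp over the first `warmup` fraction of steps, then constant beta_end.
        betas = [beta_end] * num_timesteps
        warmup_steps = int(num_timesteps * warmup)
        betas[:warmup_steps] = linspace(beta_start, beta_end, warmup_steps)
        return betas
    if schedule == "const":
        return [beta_end] * num_timesteps
    if schedule == "jsd":
        # 1/T, 1/(T-1), ..., 1/2, 1
        return [1.0 / t for t in linspace(num_timesteps, 1, num_timesteps)]
    raise ValueError(f"Unknown schedule: {schedule}")
```

For example, `beta_schedule("jsd", 4)` gives `[0.25, 1/3, 0.5, 1.0]`, so the last step under "jsd" has β = 1.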

ml.tasks.diffusion.cast_beta_schedule(schedule: str) → Literal['linear', 'quad', 'warmup', 'const', 'cosine', 'jsd']

ml.tasks.diffusion.get_diffusion_beta_schedule(schedule: Literal['linear', 'quad', 'warmup', 'const', 'cosine', 'jsd'], num_timesteps: int, *, beta_start: float = 0.0001, beta_end: float = 0.02, warmup: float = 0.1, cosine_offset: float = 0.008, dtype: dtype = torch.float32) → Tensor

Returns a beta schedule for the given schedule type.

Parameters:
  • schedule – The schedule type.

  • num_timesteps – The total number of timesteps.

  • beta_start – The initial beta value, for linear, quad, and warmup schedules.

  • beta_end – The final beta value, for linear, quad, warmup, and const schedules.

  • warmup – The fraction of timesteps to use for the warmup schedule (between 0 and 1).

  • cosine_offset – The cosine offset, for cosine schedules.

  • dtype – The dtype of the returned tensor.

Returns:

The beta schedule, a tensor with shape (num_timesteps).
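Unlike the other schedules, the cosine schedule is usually specified through ᾱ_t rather than β_t directly: ᾱ_t = f(t) / f(0) with f(t) = cos²(((t/T + s) / (1 + s)) · π/2), and β_t = 1 − ᾱ_t / ᾱ_(t−1). The sketch below assumes the module follows this formulation from Nichol & Dhariwal (2021), with `cosine_offset` playing the role of s; the function name and the `max_beta` clip at 0.999 are assumptions borrowed from that paper, not documented parameters of get_diffusion_beta_schedule.

```python
import math


def cosine_beta_schedule(
    num_timesteps: int,
    cosine_offset: float = 0.008,
    max_beta: float = 0.999,
) -> list[float]:
    """Hypothetical cosine schedule: define bar_alpha via f(t), then derive beta."""

    def f(t: float) -> float:
        # Squared cosine; the offset s keeps beta from being too small near t = 0.
        return math.cos((t / num_timesteps + cosine_offset) / (1 + cosine_offset) * math.pi / 2) ** 2

    betas = []
    for t in range(1, num_timesteps + 1):
        # beta_t = 1 - bar_alpha_t / bar_alpha_{t-1}; f(0) cancels in the ratio.
        beta = 1.0 - f(t) / f(t - 1)
        # Clip so the final steps do not reach beta = 1.
        betas.append(min(beta, max_beta))
    return betas
```

With the default offset, the first betas are very small and the last one hits the 0.999 clip, because f(T) is essentially zero.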

class ml.tasks.diffusion.GaussianDiffusion(beta_schedule: Literal['linear', 'quad', 'warmup', 'const', 'cosine', 'jsd'] = 'linear', num_beta_steps: int = 1000, pred_mode: Literal['pred_x_0', 'pred_eps', 'pred_v'] = 'pred_x_0', loss: Literal['mse', 'l1', 'pseudo-huber'] = 'mse', sigma_type: Literal['upper_bound', 'lower_bound'] = 'upper_bound', solver: Literal['euler', 'heun', 'rk4'] = 'euler', *, beta_start: float = 0.0001, beta_end: float = 0.02, warmup: float = 0.1, cosine_offset: float = 0.008)

Bases: Module

Defines a module which provides utility functions for Gaussian diffusion.

Parameters:
  • beta_schedule – The beta schedule type to use.

  • num_beta_steps – The number of beta steps to use.

  • pred_mode

    The prediction mode, which determines what the model should predict. Can be one of:

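    • "pred_x_0": Predicts the original, noise-free sample x_0.

    • "pred_eps": Predicts the noise ε that was added to x_0 to produce the current noisy sample x_t.

    • "pred_v": Predicts the "velocity" v = √ᾱ_t · ε − √(1 − ᾱ_t) · x_0 (the v-parameterization).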

  • loss

    The type of loss to use. Can be one of:

    • "mse": Mean squared error.

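    • "l1": Mean absolute error.

    • "pseudo-huber": Pseudo-Huber loss, which behaves like MSE for small errors and like L1 for large ones.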

  • sigma_type

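    The variance of the noise added at each reverse (denoising) step. Can be one of:

    • "upper_bound": The upper bound of the posterior noise (in the DDPM formulation, σ_t² = β_t).

    • "lower_bound": The lower bound of the posterior noise (in the DDPM formulation, σ_t² = β̃_t, the variance of the true posterior q(x_(t−1) | x_t, x_0)).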

  • solver – The ODE solver to use for running incremental model steps during sampling. Can be one of "euler", "heun", or "rk4"; defaults to "euler".

  • beta_start – The initial beta value, for linear, quad, and warmup schedules.

  • beta_end – The final beta value, for linear, quad, warmup, and const schedules.

  • warmup – The fraction of timesteps to use for the warmup schedule (between 0 and 1).

  • cosine_offset – The cosine offset, for cosine schedules.
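During training the model sees x_t = √ᾱ_t · x_0 + √(1 − ᾱ_t) · ε, and pred_mode selects which quantity it is asked to output. A per-element sketch, assuming this standard forward process and the standard v-parameterization v = √ᾱ_t · ε − √(1 − ᾱ_t) · x_0 (Salimans & Ho, 2022); `training_target` and `x_0_from_prediction` are hypothetical helpers, not part of ml.tasks.diffusion.

```python
import math


def training_target(pred_mode: str, x_0: float, eps: float, bar_alpha_t: float) -> float:
    """Hypothetical: what the model is trained to output for one element."""
    if pred_mode == "pred_x_0":
        return x_0
    if pred_mode == "pred_eps":
        return eps
    if pred_mode == "pred_v":
        # v-parameterization: v = sqrt(bar_alpha) * eps - sqrt(1 - bar_alpha) * x_0
        return math.sqrt(bar_alpha_t) * eps - math.sqrt(1 - bar_alpha_t) * x_0
    raise ValueError(f"Unknown pred_mode: {pred_mode}")


def x_0_from_prediction(pred_mode: str, pred: float, x_t: float, bar_alpha_t: float) -> float:
    """Hypothetical: inverts each parameterization to estimate the clean sample x_0."""
    a, b = math.sqrt(bar_alpha_t), math.sqrt(1 - bar_alpha_t)
    if pred_mode == "pred_x_0":
        return pred
    if pred_mode == "pred_eps":
        # x_t = a * x_0 + b * eps  =>  x_0 = (x_t - b * eps) / a
        return (x_t - b * pred) / a
    if pred_mode == "pred_v":
        # a * x_t - b * v = (a^2 + b^2) * x_0 = x_0
        return a * x_t - b * pred
    raise ValueError(f"Unknown pred_mode: {pred_mode}")
```

Since a perfect prediction in any mode recovers x_0 exactly, the three modes mainly differ in how prediction errors are weighted across timesteps during training.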


bar_alpha: Tensor
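bar_alpha conventionally holds ᾱ_t, the cumulative product of (1 − β_s) over the beta schedule, and the two sigma_type options match the two reverse-step variances discussed in DDPM: σ_t² = β_t (upper bound) and the true posterior variance σ_t² = β̃_t = (1 − ᾱ_(t−1)) / (1 − ᾱ_t) · β_t (lower bound). The sketch assumes that convention and 0-indexed timesteps; the exact indexing and shape of the module's bar_alpha buffer are not documented here, and both helper names are hypothetical.

```python
import math


def bar_alpha_from_betas(betas: list[float]) -> list[float]:
    """bar_alpha_t = prod_{s <= t} (1 - beta_s), i.e. a running product."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def reverse_variance(betas: list[float], t: int, sigma_type: str) -> float:
    """Hypothetical: DDPM's two choices for sigma_t^2 at 0-indexed step t."""
    if sigma_type == "upper_bound":
        return betas[t]
    if sigma_type == "lower_bound":
        bar_alpha = bar_alpha_from_betas(betas)
        prev = bar_alpha[t - 1] if t > 0 else 1.0  # bar_alpha before the first step is 1
        # Posterior variance: beta_tilde_t = (1 - bar_alpha_{t-1}) / (1 - bar_alpha_t) * beta_t
        return (1.0 - prev) / (1.0 - bar_alpha[t]) * betas[t]
    raise ValueError(f"Unknown sigma_type: {sigma_type}")
```

Because ᾱ_t never increases with t, the lower-bound variance is never larger than the upper-bound one, and it is exactly zero at the first step.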
loss(model: Callable[[Tensor, Tensor], Tensor], x: Tensor) → Tensor

Computes the loss for a given sample.

Parameters:
  • model – The model forward process, which takes a tensor with the same shape as the input data plus a timestep and returns the predicted noise or target, with shape (*).

  • x – The input data, with shape (*).

Returns:

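The unreduced, per-element loss, with the same shape (*) as x; reduce it (for example with .mean()) before calling backward().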
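Because loss returns an unreduced tensor with the same shape as its input, it has to be reduced to a scalar before backpropagating. The sketch below mirrors this on plain Python lists for the three documented loss types; `elementwise_loss` is a hypothetical helper, and the pseudo-Huber scale `c` is an assumed value, since the constant the module uses is not documented here.

```python
import math


def elementwise_loss(pred: list[float], target: list[float], loss: str = "mse", c: float = 0.01) -> list[float]:
    """Hypothetical unreduced per-element loss; `c` is an assumed pseudo-Huber scale."""
    if loss == "mse":
        return [(p - t) ** 2 for p, t in zip(pred, target)]
    if loss == "l1":
        return [abs(p - t) for p, t in zip(pred, target)]
    if loss == "pseudo-huber":
        # sqrt(d^2 + c^2) - c: quadratic for |d| << c, roughly |d| - c for |d| >> c.
        return [math.sqrt((p - t) ** 2 + c**2) - c for p, t in zip(pred, target)]
    raise ValueError(f"Unknown loss: {loss}")


pred = [0.5, -1.0, 2.0]
target = [0.0, -1.0, 1.0]
per_element = elementwise_loss(pred, target, "mse")  # same length as the input
scalar = sum(per_element) / len(per_element)  # mean reduction before backpropagating
```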

partial_sample(model: Callable[[Tensor, Tensor], Tensor], reference_sample: Tensor, start_percent: float, sampling_timesteps: int | None = None, solver: BaseODESolver | None = None) → Tensor

Samples from the model, starting from a given reference sample.

Partial sampling takes a reference sample, adds noise to it, and then denoises the noisy sample using the model. This can be used for style transfer: the reference sample is the source image, and the denoising process pushes it toward the target style the model has learned.

Parameters:
  • model – The model forward process, which takes a tensor with the same shape as the input data plus a timestep and returns the predicted noise or target, with shape (*).

  • reference_sample – The reference sample, with shape (*).

  • start_percent – The fraction of the diffusion process to skip, between 0 and 1. A value of 0 means that all of the diffusion steps will be used, while 1 means that none of them will be used.

  • sampling_timesteps – The number of timesteps to sample for. If None, then the full number of timesteps will be used.

  • solver – The ODE solver to use for running incremental model steps. If not set, the module's built-in ODE solver is used.

Returns:

The samples, with shape (sampling_timesteps + 1, *).
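The first half of partial sampling, noising the reference before the model denoises it, can be sketched as below. It assumes that start_percent maps linearly onto the number of denoising steps that are run and that noise is added with the closed-form forward process x_t = √ᾱ_t · x_ref + √(1 − ᾱ_t) · ε; `partial_sample_start` and its return convention are hypothetical, and the model-driven denoising loop is omitted.

```python
import math
import random


def partial_sample_start(
    reference: list[float],
    bar_alpha: list[float],
    start_percent: float,
    rng: random.Random,
) -> tuple[int, list[float]]:
    """Hypothetical: returns (number of denoising steps to run, noised reference).

    Assumed mapping: start_percent = 0 starts at the last (noisiest) step,
    start_percent = 1 skips every step.
    """
    num_steps = len(bar_alpha)
    steps_to_run = round((1.0 - start_percent) * num_steps)
    if steps_to_run == 0:
        return 0, list(reference)  # nothing to denoise; keep the reference as-is
    t = steps_to_run - 1  # 0-indexed timestep to start denoising from
    a, b = math.sqrt(bar_alpha[t]), math.sqrt(1.0 - bar_alpha[t])
    # Forward-diffuse the reference to step t in one shot.
    noised = [a * x + b * rng.gauss(0.0, 1.0) for x in reference]
    return steps_to_run, noised
```

Smaller start_percent values add more noise, so more of the reference is overwritten by the model; larger values keep the output closer to the reference.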

sample(model: Callable[[Tensor, Tensor], Tensor], shape: tuple[int, ...], device: device, sampling_timesteps: int | None = None, solver: BaseODESolver | None = None) → Tensor

Samples from the model.

Parameters:
  • model – The model forward process, which takes a tensor with the same shape as the input data plus a timestep and returns the predicted noise or target, with shape (*).

  • shape – The shape of the samples.

  • device – The device to put the samples on.

  • sampling_timesteps – The number of timesteps to sample for. If None, then the full number of timesteps will be used.

  • solver – The ODE solver to use for running incremental model steps. If not set, the module's built-in ODE solver is used.

Returns:

The samples, with shape (sampling_timesteps + 1, *).
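The solver names accepted by GaussianDiffusion ("euler", "heun", "rk4") refer to standard explicit ODE integrators. The sketch below shows their textbook single-step updates for a scalar ODE dx/dt = f(t, x); it does not reproduce the BaseODESolver interface, which is not documented here, and all names in it are hypothetical. It is meant only to show the cost/accuracy trade-off: Euler evaluates f once per step, Heun twice, and RK4 four times.

```python
import math
from typing import Callable

Deriv = Callable[[float, float], float]


def euler_step(f: Deriv, t: float, x: float, h: float) -> float:
    # One derivative evaluation per step; first-order accurate.
    return x + h * f(t, x)


def heun_step(f: Deriv, t: float, x: float, h: float) -> float:
    # Euler predictor, then average the slopes at both ends; second-order accurate.
    k1 = f(t, x)
    k2 = f(t + h, x + h * k1)
    return x + h * (k1 + k2) / 2


def rk4_step(f: Deriv, t: float, x: float, h: float) -> float:
    # Classic fourth-order Runge-Kutta: four derivative evaluations per step.
    k1 = f(t, x)
    k2 = f(t + h / 2, x + h * k1 / 2)
    k3 = f(t + h / 2, x + h * k2 / 2)
    k4 = f(t + h, x + h * k3)
    return x + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6


def integrate(step: Callable[[Deriv, float, float, float], float], f: Deriv, x0: float, t1: float, n: int) -> float:
    """Integrates from t = 0 to t = t1 using n equal steps of the given solver."""
    h, x = t1 / n, x0
    for i in range(n):
        x = step(f, i * h, x, h)
    return x
```

On dx/dt = −x integrated from 0 to 1 in 10 steps, the error against e^(−1) shrinks from Euler to Heun to RK4. In a diffusion sampler each evaluation of f typically costs one model forward pass, so a higher-order solver trades more evaluations per step for the option of using fewer sampling_timesteps.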

ml.tasks.diffusion.plot_schedules(*, num_timesteps: int = 100, output_file: str | Path | None = None) → None

Plots all of the schedules together on one graph.

Parameters:
  • num_timesteps – The number of timesteps to plot.

  • output_file – The file to save the plot to. If None, then the plot will be shown instead.