ml.tasks.ode

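Defines components for solving ordinary differential equations (ODEs): an abstract solver base class, Euler, Heun, and fourth-order Runge-Kutta solvers, and helpers for selecting a solver by key.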

ml.tasks.ode.cast_solver_type(s: str) -> Literal['euler', 'heun', 'rk4']
ml.tasks.ode.vanilla_add_fn(a: Tensor, b: Tensor, ta: Tensor, tb: Tensor) -> Tensor
class ml.tasks.ode.BaseODESolver

Bases: ABC

abstract step(samples: Tensor, t: Tensor, next_t: Tensor, func: Callable[[Tensor, Tensor], Tensor], add_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] = vanilla_add_fn) -> Tensor

Steps the current state forward in time.

Parameters:
  • samples – The current samples, with shape (N, *).

  • t – The current time step, with shape (N).

  • next_t – The next time step, with shape (N).

  • func – The function to use to compute the derivative, with signature (samples, t) -> deriv.

  • add_fn – The addition function to use, with signature (a, b, ta, tb) -> Tensor, expected to return a + b * (tb - ta). Defaults to vanilla_add_fn.

Returns:

The next sample, with shape (N, *).
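The add_fn contract described above can be sketched with plain Python floats instead of tensors. The name add_scaled is hypothetical and only mirrors the documented formula a + b * (tb - ta); it is not the library's vanilla_add_fn.

```python
def add_scaled(a: float, b: float, ta: float, tb: float) -> float:
    """Returns a + b * (tb - ta), the update documented for add_fn."""
    return a + b * (tb - ta)


# Advance a state of 1.0 with derivative 2.0 from t = 0.0 to t = 0.5.
next_state = add_scaled(1.0, 2.0, 0.0, 0.5)
```

Passing the update rule in as a function lets a caller swap in a different way of combining state and derivative without changing the solver itself.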

class ml.tasks.ode.EulerODESolver

Bases: BaseODESolver

The Euler method for solving ODEs.

step(samples: Tensor, t: Tensor, next_t: Tensor, func: Callable[[Tensor, Tensor], Tensor], add_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] = vanilla_add_fn) -> Tensor

Steps the current state forward in time.

Parameters:
  • samples – The current samples, with shape (N, *).

  • t – The current time step, with shape (N).

  • next_t – The next time step, with shape (N).

  • func – The function to use to compute the derivative, with signature (samples, t) -> deriv.

  • add_fn – The addition function to use, with signature (a, b, ta, tb) -> Tensor, expected to return a + b * (tb - ta). Defaults to vanilla_add_fn.

Returns:

The next sample, with shape (N, *).
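As an illustration of what a single Euler step computes, here is a minimal standalone sketch using plain Python floats rather than torch tensors. The names euler_step and decay are hypothetical, and the library's EulerODESolver may differ in detail (for example, in how it applies add_fn).

```python
import math
from typing import Callable

Deriv = Callable[[float, float], float]


def euler_step(x: float, t: float, next_t: float, func: Deriv) -> float:
    """One explicit Euler step: x + f(x, t) * (next_t - t)."""
    return x + func(x, t) * (next_t - t)


def decay(x: float, t: float) -> float:
    """Test problem dx/dt = -x, whose exact solution is x(t) = exp(-t)."""
    return -x


# Integrate from t = 0 to t = 1 in 10 equal steps, starting at x(0) = 1.
n = 10
x = 1.0
for i in range(n):
    x = euler_step(x, i / n, (i + 1) / n, decay)

exact = math.exp(-1.0)
error = abs(x - exact)
```

Euler is first-order accurate, so halving the step size roughly halves the final error on a smooth problem like this one.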

class ml.tasks.ode.HeunODESolver

Bases: BaseODESolver

The Heun method for solving ODEs.

step(samples: Tensor, t: Tensor, next_t: Tensor, func: Callable[[Tensor, Tensor], Tensor], add_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] = vanilla_add_fn) -> Tensor

Steps the current state forward in time.

Parameters:
  • samples – The current samples, with shape (N, *).

  • t – The current time step, with shape (N).

  • next_t – The next time step, with shape (N).

  • func – The function to use to compute the derivative, with signature (samples, t) -> deriv.

  • add_fn – The addition function to use, with signature (a, b, ta, tb) -> Tensor, expected to return a + b * (tb - ta). Defaults to vanilla_add_fn.

Returns:

The next sample, with shape (N, *).
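Heun's method, as commonly defined, takes an Euler predictor step and then averages the slopes at both ends of the interval. The sketch below uses plain Python floats and hypothetical names (heun_step, decay); it shows the textbook scheme and is not necessarily how HeunODESolver is implemented.

```python
import math
from typing import Callable

Deriv = Callable[[float, float], float]


def heun_step(x: float, t: float, next_t: float, func: Deriv) -> float:
    """One Heun (explicit trapezoidal) step."""
    dt = next_t - t
    k1 = func(x, t)                  # slope at the start of the interval
    k2 = func(x + k1 * dt, next_t)   # slope at the Euler-predicted endpoint
    return x + 0.5 * (k1 + k2) * dt  # advance with the averaged slope


def decay(x: float, t: float) -> float:
    """Test problem dx/dt = -x, whose exact solution is x(t) = exp(-t)."""
    return -x


# Integrate from t = 0 to t = 1 in 10 equal steps, starting at x(0) = 1.
n = 10
x = 1.0
for i in range(n):
    x = heun_step(x, i / n, (i + 1) / n, decay)

error = abs(x - math.exp(-1.0))
```

The method costs two derivative evaluations per step and is second-order accurate, so it is typically much closer to the exact solution than Euler at the same step size.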

class ml.tasks.ode.RK4ODESolver

Bases: BaseODESolver

The fourth-order Runge-Kutta method for solving ODEs.

step(samples: Tensor, t: Tensor, next_t: Tensor, func: Callable[[Tensor, Tensor], Tensor], add_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] = vanilla_add_fn) -> Tensor

Steps the current state forward in time.

Parameters:
  • samples – The current samples, with shape (N, *).

  • t – The current time step, with shape (N).

  • next_t – The next time step, with shape (N).

  • func – The function to use to compute the derivative, with signature (samples, t) -> deriv.

  • add_fn – The addition function to use, with signature (a, b, ta, tb) -> Tensor, expected to return a + b * (tb - ta). Defaults to vanilla_add_fn.

Returns:

The next sample, with shape (N, *).
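The classical fourth-order Runge-Kutta scheme combines four slope evaluations per step. The following is a minimal sketch with plain Python floats and hypothetical names (rk4_step, decay); it follows the textbook formula and is not taken from RK4ODESolver's source.

```python
import math
from typing import Callable

Deriv = Callable[[float, float], float]


def rk4_step(x: float, t: float, next_t: float, func: Deriv) -> float:
    """One classical RK4 step from t to next_t."""
    dt = next_t - t
    mid_t = t + 0.5 * dt
    k1 = func(x, t)                       # slope at the start
    k2 = func(x + 0.5 * dt * k1, mid_t)   # slope at the midpoint, using k1
    k3 = func(x + 0.5 * dt * k2, mid_t)   # slope at the midpoint, using k2
    k4 = func(x + dt * k3, next_t)        # slope at the end, using k3
    return x + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def decay(x: float, t: float) -> float:
    """Test problem dx/dt = -x, whose exact solution is x(t) = exp(-t)."""
    return -x


# Integrate from t = 0 to t = 1 in 10 equal steps, starting at x(0) = 1.
n = 10
x = 1.0
for i in range(n):
    x = rk4_step(x, i / n, (i + 1) / n, decay)

error = abs(x - math.exp(-1.0))
```

RK4 uses four derivative evaluations per step, which matters when func is an expensive model call; in exchange, it usually reaches a given accuracy with far fewer steps than Euler or Heun.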

ml.tasks.ode.get_ode_solver(s: Literal['euler', 'heun', 'rk4']) -> BaseODESolver

Returns an ODE solver for a given key.

Parameters:

s – The solver key to retrieve.

Returns:

The solver object.
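One plausible way to validate a string key and map it to a solver object is a small registry, sketched below. Every name here (SolverSketch, EulerSketch, HeunSketch, RK4Sketch, REGISTRY, cast_key, get_solver) is hypothetical and stands in for the library's own classes and functions; the actual behavior of cast_solver_type and get_ode_solver, including how they report unknown keys, is an assumption.

```python
from typing import Dict, Literal, Type, cast, get_args

SolverKey = Literal["euler", "heun", "rk4"]


class SolverSketch:
    """Hypothetical stand-in for a solver base class."""


class EulerSketch(SolverSketch):
    """Hypothetical stand-in for an Euler solver."""


class HeunSketch(SolverSketch):
    """Hypothetical stand-in for a Heun solver."""


class RK4Sketch(SolverSketch):
    """Hypothetical stand-in for a fourth-order Runge-Kutta solver."""


# Maps each supported key to the class that should be instantiated for it.
REGISTRY: Dict[str, Type[SolverSketch]] = {
    "euler": EulerSketch,
    "heun": HeunSketch,
    "rk4": RK4Sketch,
}


def cast_key(s: str) -> SolverKey:
    """Narrows an arbitrary string to a known solver key, or raises."""
    if s not in get_args(SolverKey):
        raise ValueError(f"Unknown solver key: {s!r}")
    return cast(SolverKey, s)


def get_solver(key: SolverKey) -> SolverSketch:
    """Instantiates the solver registered under the given key."""
    return REGISTRY[key]()


# Typical flow: validate a user-supplied string, then build the solver.
solver = get_solver(cast_key("rk4"))
```

Separating validation from construction keeps configuration parsing (plain strings) apart from code that relies on the narrower Literal type.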