ml.tasks.ode
Defines single-step solvers for ordinary differential equations (ODEs), plus a helper for looking one up by name.
- class ml.tasks.ode.BaseODESolver
Bases: ABC
- abstract step(samples: Tensor, t: Tensor, next_t: Tensor, func: Callable[[Tensor, Tensor], Tensor], add_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] = vanilla_add_fn) → Tensor
Steps the current state forward in time.
- Parameters:
  - samples – The current samples, with shape (N, *).
  - t – The current time, with shape (N).
  - next_t – The time to step to, with shape (N).
  - func – The function used to compute the derivative, with signature (samples, t) -> deriv.
  - add_fn – The addition function, with signature (a, b, ta, tb) -> a + b * (tb - ta).
- Returns:
  The samples at next_t, with shape (N, *).
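The step contract above can be sketched in plain Python. This is a minimal illustration, not the module's source: plain floats stand in for torch.Tensor so it runs without PyTorch, `vanilla_add_fn` is written from its documented formula, and `MidpointODESolver` is a hypothetical subclass (the explicit midpoint method, which this module does not provide) showing how a new solver would implement `step` by calling `func` for slopes and `add_fn` for every update.

```python
from abc import ABC, abstractmethod
from typing import Callable

# Assumption: plain floats stand in for torch.Tensor in this sketch.
Func = Callable[[float, float], float]                 # (samples, t) -> deriv
AddFn = Callable[[float, float, float, float], float]  # (a, b, ta, tb) -> result


def vanilla_add_fn(a: float, b: float, ta: float, tb: float) -> float:
    # Documented default: move `a` along slope `b` for a time span of tb - ta.
    return a + b * (tb - ta)


class BaseODESolver(ABC):
    @abstractmethod
    def step(
        self,
        samples: float,
        t: float,
        next_t: float,
        func: Func,
        add_fn: AddFn = vanilla_add_fn,
    ) -> float:
        """Advance `samples` from time `t` to time `next_t`."""


class MidpointODESolver(BaseODESolver):
    """Hypothetical subclass (not part of ml.tasks.ode): explicit midpoint method."""

    def step(
        self,
        samples: float,
        t: float,
        next_t: float,
        func: Func,
        add_fn: AddFn = vanilla_add_fn,
    ) -> float:
        mid_t = t + (next_t - t) / 2
        # Half step using the slope at t, to estimate the state at the midpoint.
        k1 = func(samples, t)
        k2 = func(add_fn(samples, k1, t, mid_t), mid_t)
        # Full step from the original samples using the midpoint slope.
        return add_fn(samples, k2, t, next_t)


# dx/dt = x, one step from t = 0 to t = 1 starting at x = 1.
result = MidpointODESolver().step(1.0, 0.0, 1.0, lambda y, t: y)
```

Routing every update through `add_fn` instead of writing `samples + deriv * dt` inline plausibly lets callers change how an increment is applied without touching the solver; that rationale is inferred from the signature, not stated on this page.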
- class ml.tasks.ode.EulerODESolver
Bases: BaseODESolver
The Euler method for solving ODEs.
- step(samples: Tensor, t: Tensor, next_t: Tensor, func: Callable[[Tensor, Tensor], Tensor], add_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] = vanilla_add_fn) → Tensor
Steps the current state forward in time.
- Parameters:
  - samples – The current samples, with shape (N, *).
  - t – The current time, with shape (N).
  - next_t – The time to step to, with shape (N).
  - func – The function used to compute the derivative, with signature (samples, t) -> deriv.
  - add_fn – The addition function, with signature (a, b, ta, tb) -> a + b * (tb - ta).
- Returns:
  The samples at next_t, with shape (N, *).
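A minimal sketch of a single Euler step, assuming the documented default `add_fn` and using plain floats in place of tensors. It illustrates the method, not this class's actual source; `euler_step` is a hypothetical helper. The Euler method is first order: it evaluates the derivative once, at `t`, and takes a full step of length `next_t - t` along that slope.

```python
from typing import Callable

Func = Callable[[float, float], float]                 # (samples, t) -> deriv
AddFn = Callable[[float, float, float, float], float]  # (a, b, ta, tb) -> result


def vanilla_add_fn(a: float, b: float, ta: float, tb: float) -> float:
    # Documented default: a + b * (tb - ta).
    return a + b * (tb - ta)


def euler_step(
    samples: float,
    t: float,
    next_t: float,
    func: Func,
    add_fn: AddFn = vanilla_add_fn,
) -> float:
    # One derivative evaluation at the start of the interval...
    slope = func(samples, t)
    # ...then a full step along that slope.
    return add_fn(samples, slope, t, next_t)


# dx/dt = x with x(0) = 1, integrated to t = 1 in ten steps of size 0.1.
x = 1.0
for i in range(10):
    x = euler_step(x, i / 10, (i + 1) / 10, lambda y, t: y)
```

On this test problem each step multiplies x by 1.1, so ten steps give 1.1^10 ≈ 2.5937 against the exact value e ≈ 2.7183; the error shrinks roughly in proportion to the step size, as expected for a first-order method.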
- class ml.tasks.ode.HeunODESolver
Bases: BaseODESolver
The Heun method for solving ODEs.
- step(samples: Tensor, t: Tensor, next_t: Tensor, func: Callable[[Tensor, Tensor], Tensor], add_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] = vanilla_add_fn) → Tensor
Steps the current state forward in time.
- Parameters:
  - samples – The current samples, with shape (N, *).
  - t – The current time, with shape (N).
  - next_t – The time to step to, with shape (N).
  - func – The function used to compute the derivative, with signature (samples, t) -> deriv.
  - add_fn – The addition function, with signature (a, b, ta, tb) -> a + b * (tb - ta).
- Returns:
  The samples at next_t, with shape (N, *).
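Heun's method is a second-order predictor-corrector scheme: an Euler step predicts the state at `next_t`, and the final step uses the average of the slopes at both ends. A minimal sketch, assuming the documented default `add_fn` and plain floats in place of tensors; `heun_step` is a hypothetical helper, not this class's actual source.

```python
from typing import Callable

Func = Callable[[float, float], float]                 # (samples, t) -> deriv
AddFn = Callable[[float, float, float, float], float]  # (a, b, ta, tb) -> result


def vanilla_add_fn(a: float, b: float, ta: float, tb: float) -> float:
    # Documented default: a + b * (tb - ta).
    return a + b * (tb - ta)


def heun_step(
    samples: float,
    t: float,
    next_t: float,
    func: Func,
    add_fn: AddFn = vanilla_add_fn,
) -> float:
    k1 = func(samples, t)                       # slope at the start
    predicted = add_fn(samples, k1, t, next_t)  # Euler predictor at next_t
    k2 = func(predicted, next_t)                # slope at the predicted end point
    # Corrector: step from the original samples with the averaged slope.
    return add_fn(samples, (k1 + k2) / 2, t, next_t)


# dx/dt = x with x(0) = 1, integrated to t = 1 in ten steps of size 0.1.
x = 1.0
for i in range(10):
    x = heun_step(x, i / 10, (i + 1) / 10, lambda y, t: y)
```

Here each step multiplies x by 1 + h + h²/2 = 1.105, so ten steps land at about 2.7141, roughly 30 times closer to e than ten Euler steps, at the cost of two derivative evaluations per step instead of one.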
- class ml.tasks.ode.RK4ODESolver
Bases: BaseODESolver
The fourth-order Runge-Kutta method for solving ODEs.
- step(samples: Tensor, t: Tensor, next_t: Tensor, func: Callable[[Tensor, Tensor], Tensor], add_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] = vanilla_add_fn) → Tensor
Steps the current state forward in time.
- Parameters:
  - samples – The current samples, with shape (N, *).
  - t – The current time, with shape (N).
  - next_t – The time to step to, with shape (N).
  - func – The function used to compute the derivative, with signature (samples, t) -> deriv.
  - add_fn – The addition function, with signature (a, b, ta, tb) -> a + b * (tb - ta).
- Returns:
  The samples at next_t, with shape (N, *).
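The classic fourth-order Runge-Kutta step evaluates the derivative four times (once at `t`, twice at the midpoint, once at `next_t`) and combines the slopes with weights 1/6, 2/6, 2/6, 1/6. A minimal sketch, assuming the documented default `add_fn` and plain floats in place of tensors; `rk4_step` is a hypothetical helper, not this class's actual source. Note that `add_fn(samples, k, t, mid_t)` takes a half step, because its time span is `mid_t - t`.

```python
import math
from typing import Callable

Func = Callable[[float, float], float]                 # (samples, t) -> deriv
AddFn = Callable[[float, float, float, float], float]  # (a, b, ta, tb) -> result


def vanilla_add_fn(a: float, b: float, ta: float, tb: float) -> float:
    # Documented default: a + b * (tb - ta).
    return a + b * (tb - ta)


def rk4_step(
    samples: float,
    t: float,
    next_t: float,
    func: Func,
    add_fn: AddFn = vanilla_add_fn,
) -> float:
    mid_t = t + (next_t - t) / 2
    k1 = func(samples, t)                                # slope at the start
    k2 = func(add_fn(samples, k1, t, mid_t), mid_t)      # midpoint slope via k1
    k3 = func(add_fn(samples, k2, t, mid_t), mid_t)      # midpoint slope via k2
    k4 = func(add_fn(samples, k3, t, next_t), next_t)    # end slope via k3
    # Weighted average of the four slopes, applied over the full interval.
    return add_fn(samples, (k1 + 2 * k2 + 2 * k3 + k4) / 6, t, next_t)


# dx/dt = x with x(0) = 1, integrated to t = 1 in ten steps of size 0.1.
x = 1.0
for i in range(10):
    x = rk4_step(x, i / 10, (i + 1) / 10, lambda y, t: y)
error = abs(x - math.e)
```

For this problem a single RK4 step from 0 to 1 reproduces the Taylor series of e through the h⁴ term (65/24 ≈ 2.7083), and ten steps of size 0.1 leave an error of roughly 2e-6.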
- ml.tasks.ode.get_ode_solver(s: Literal['euler', 'heun', 'rk4']) → BaseODESolver
Returns an ODE solver for a given key.
- Parameters:
  - s – The solver key to retrieve: 'euler', 'heun', or 'rk4'.
- Returns:
  A BaseODESolver instance implementing the requested method.
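A hypothetical sketch of a key-to-solver lookup in the spirit of get_ode_solver, plus a loop that applies the returned solver over a uniform time grid. The class names, `get_solver`, `integrate`, and the ValueError raised for an unknown key are assumptions of this sketch, not this module's code; plain floats stand in for tensors.

```python
import math
from typing import Callable, Literal

Func = Callable[[float, float], float]  # (samples, t) -> deriv
SolverKey = Literal["euler", "heun", "rk4"]


def vanilla_add_fn(a: float, b: float, ta: float, tb: float) -> float:
    # Documented default: a + b * (tb - ta).
    return a + b * (tb - ta)


class EulerSolver:
    def step(self, x, t, next_t, func, add_fn=vanilla_add_fn):
        return add_fn(x, func(x, t), t, next_t)


class HeunSolver:
    def step(self, x, t, next_t, func, add_fn=vanilla_add_fn):
        k1 = func(x, t)
        k2 = func(add_fn(x, k1, t, next_t), next_t)
        return add_fn(x, (k1 + k2) / 2, t, next_t)


class RK4Solver:
    def step(self, x, t, next_t, func, add_fn=vanilla_add_fn):
        mid = t + (next_t - t) / 2
        k1 = func(x, t)
        k2 = func(add_fn(x, k1, t, mid), mid)
        k3 = func(add_fn(x, k2, t, mid), mid)
        k4 = func(add_fn(x, k3, t, next_t), next_t)
        return add_fn(x, (k1 + 2 * k2 + 2 * k3 + k4) / 6, t, next_t)


def get_solver(s: SolverKey):
    # Hypothetical lookup: map each key to a fresh solver instance.
    solvers = {"euler": EulerSolver, "heun": HeunSolver, "rk4": RK4Solver}
    if s not in solvers:
        raise ValueError(f"Unknown solver key: {s!r}")
    return solvers[s]()


def integrate(s: SolverKey, x0: float, t0: float, t1: float, num_steps: int, func: Func) -> float:
    # Apply one solver step per interval of a uniform grid from t0 to t1.
    solver = get_solver(s)
    x = x0
    for i in range(num_steps):
        t = t0 + (t1 - t0) * i / num_steps
        next_t = t0 + (t1 - t0) * (i + 1) / num_steps
        x = solver.step(x, t, next_t, func)
    return x


# Compare the three methods on dx/dt = x, x(0) = 1, whose exact value at t = 1 is e.
errors = {
    key: abs(integrate(key, 1.0, 0.0, 1.0, 10, lambda y, t: y) - math.e)
    for key in ("euler", "heun", "rk4")
}
```

Because every solver in this sketch exposes the same `step` signature, switching methods is a one-word change to the key; with ten steps the error drops from about 0.12 (Euler) to about 0.004 (Heun) to about 2e-6 (RK4).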