ml.tasks.ode
Defines some components for dealing with ordinary differential equations.
- class ml.tasks.ode.BaseODESolver
Bases: ABC
- abstract step(samples: torch.Tensor, t: torch.Tensor, next_t: torch.Tensor, func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], add_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] = vanilla_add_fn) -> torch.Tensor
Steps the current state forward in time.
- Parameters:
samples – The current samples, with shape (N, *).
t – The current time, with shape (N).
next_t – The time to step to, with shape (N).
func – The function that computes the derivative, with signature (samples, t) -> deriv.
add_fn – The addition function to use, with signature (a, b, ta, tb) -> a + b * (tb - ta).
- Returns:
The samples at time next_t, with shape (N, *).
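The add_fn contract above can be illustrated with a minimal sketch of a default addition function. The body of the real vanilla_add_fn is not documented here, so this version is an assumption: it computes a + b * (tb - ta) and reshapes the (N) time difference to (N, 1, ..., 1) so it broadcasts against (N, *) samples.

```python
import torch


def vanilla_add_fn(a: torch.Tensor, b: torch.Tensor, ta: torch.Tensor, tb: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour: a + b * (tb - ta), applied per sample.
    # ta and tb have shape (N,), while a and b have shape (N, *), so the time
    # difference is reshaped to (N, 1, ..., 1) before it is broadcast.
    dt = (tb - ta).reshape(-1, *([1] * (b.dim() - 1)))
    return a + b * dt


a = torch.zeros(2, 3)
b = torch.ones(2, 3)
ta = torch.tensor([0.0, 0.0])
tb = torch.tensor([0.5, 1.0])
out = vanilla_add_fn(a, b, ta, tb)  # row 0 is all 0.5, row 1 is all 1.0
```

Without the reshape, b * (tb - ta) would align the (N) vector with the last sample dimension instead of the batch dimension, which either fails to broadcast or silently mixes samples when that dimension happens to equal N.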
- class ml.tasks.ode.EulerODESolver
Bases: BaseODESolver
The Euler method for solving ODEs, a first-order explicit scheme.
- step(samples: torch.Tensor, t: torch.Tensor, next_t: torch.Tensor, func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], add_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] = vanilla_add_fn) -> torch.Tensor
Steps the current state forward in time.
- Parameters:
samples – The current samples, with shape (N, *).
t – The current time, with shape (N).
next_t – The time to step to, with shape (N).
func – The function that computes the derivative, with signature (samples, t) -> deriv.
add_fn – The addition function to use, with signature (a, b, ta, tb) -> a + b * (tb - ta).
- Returns:
The samples at time next_t, with shape (N, *).
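A single Euler step evaluates the derivative once at (samples, t) and moves along it for next_t - t, which maps directly onto add_fn. The sketch below is a standalone, hypothetical euler_step function (not the library's code) paired with an assumed vanilla_add_fn that follows the documented signature.

```python
from typing import Callable

import torch

Tensor = torch.Tensor


def vanilla_add_fn(a: Tensor, b: Tensor, ta: Tensor, tb: Tensor) -> Tensor:
    # Assumed default: a + b * (tb - ta), broadcasting the (N,) times over (N, *).
    dt = (tb - ta).reshape(-1, *([1] * (b.dim() - 1)))
    return a + b * dt


def euler_step(
    samples: Tensor,
    t: Tensor,
    next_t: Tensor,
    func: Callable[[Tensor, Tensor], Tensor],
    add_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] = vanilla_add_fn,
) -> Tensor:
    # x_{n+1} = x_n + f(x_n, t_n) * (t_{n+1} - t_n)
    return add_fn(samples, func(samples, t), t, next_t)


# One step of dx/dt = -x from t = 0 to t = 0.1, starting at x = 1.
x = torch.ones(3, 2, dtype=torch.float64)
t = torch.zeros(3, dtype=torch.float64)
next_t = torch.full((3,), 0.1, dtype=torch.float64)
x_next = euler_step(x, t, next_t, lambda s, tt: -s)  # every entry is 0.9
```

The exact solution at t = 0.1 is exp(-0.1), roughly 0.9048; the gap of about 0.005 reflects the first-order accuracy of a single Euler step.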
- class ml.tasks.ode.HeunODESolver
Bases: BaseODESolver
The Heun method for solving ODEs, a second-order explicit scheme (an Euler predictor followed by a trapezoidal corrector).
- step(samples: torch.Tensor, t: torch.Tensor, next_t: torch.Tensor, func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], add_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] = vanilla_add_fn) -> torch.Tensor
Steps the current state forward in time.
- Parameters:
samples – The current samples, with shape (N, *).
t – The current time, with shape (N).
next_t – The time to step to, with shape (N).
func – The function that computes the derivative, with signature (samples, t) -> deriv.
add_fn – The addition function to use, with signature (a, b, ta, tb) -> a + b * (tb - ta).
- Returns:
The samples at time next_t, with shape (N, *).
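Heun's method is usually written as an Euler predictor followed by a corrector that averages the derivatives at both ends of the interval. The sketch below assumes that standard form and routes both stages through add_fn; whether HeunODESolver does exactly this is not stated here, and heun_step and vanilla_add_fn are illustrative names.

```python
from typing import Callable

import torch

Tensor = torch.Tensor


def vanilla_add_fn(a: Tensor, b: Tensor, ta: Tensor, tb: Tensor) -> Tensor:
    # Assumed default: a + b * (tb - ta), broadcasting the (N,) times over (N, *).
    dt = (tb - ta).reshape(-1, *([1] * (b.dim() - 1)))
    return a + b * dt


def heun_step(
    samples: Tensor,
    t: Tensor,
    next_t: Tensor,
    func: Callable[[Tensor, Tensor], Tensor],
    add_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] = vanilla_add_fn,
) -> Tensor:
    k1 = func(samples, t)  # slope at the start of the interval
    predicted = add_fn(samples, k1, t, next_t)  # Euler predictor
    k2 = func(predicted, next_t)  # slope at the predicted end point
    return add_fn(samples, (k1 + k2) / 2, t, next_t)  # trapezoidal corrector


# One step of dx/dt = -x from t = 0 to t = 0.1, starting at x = 1.
x = torch.ones(2, dtype=torch.float64)
t = torch.zeros(2, dtype=torch.float64)
next_t = torch.full((2,), 0.1, dtype=torch.float64)
x_next = heun_step(x, t, next_t, lambda s, tt: -s)  # every entry is 0.905
```

Heun costs two derivative evaluations per step instead of one, and in exchange lands much closer to exp(-0.1) than a single Euler step does.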
- class ml.tasks.ode.RK4ODESolver
Bases: BaseODESolver
The fourth-order Runge-Kutta method for solving ODEs.
- step(samples: torch.Tensor, t: torch.Tensor, next_t: torch.Tensor, func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], add_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] = vanilla_add_fn) -> torch.Tensor
Steps the current state forward in time.
- Parameters:
samples – The current samples, with shape (N, *).
t – The current time, with shape (N).
next_t – The time to step to, with shape (N).
func – The function that computes the derivative, with signature (samples, t) -> deriv.
add_fn – The addition function to use, with signature (a, b, ta, tb) -> a + b * (tb - ta).
- Returns:
The samples at time next_t, with shape (N, *).
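The classical fourth-order Runge-Kutta scheme evaluates the derivative four times per step (once at the start, twice at the midpoint, once at the end) and combines the slopes with weights 1/6, 2/6, 2/6, 1/6. The sketch below assumes the classical variant and assumes the intermediate states are also formed with add_fn, using the midpoint time (t + next_t) / 2; rk4_step and vanilla_add_fn are hypothetical names.

```python
from typing import Callable

import torch

Tensor = torch.Tensor


def vanilla_add_fn(a: Tensor, b: Tensor, ta: Tensor, tb: Tensor) -> Tensor:
    # Assumed default: a + b * (tb - ta), broadcasting the (N,) times over (N, *).
    dt = (tb - ta).reshape(-1, *([1] * (b.dim() - 1)))
    return a + b * dt


def rk4_step(
    samples: Tensor,
    t: Tensor,
    next_t: Tensor,
    func: Callable[[Tensor, Tensor], Tensor],
    add_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] = vanilla_add_fn,
) -> Tensor:
    mid_t = (t + next_t) / 2
    k1 = func(samples, t)  # slope at the start
    k2 = func(add_fn(samples, k1, t, mid_t), mid_t)  # midpoint slope using k1
    k3 = func(add_fn(samples, k2, t, mid_t), mid_t)  # midpoint slope using k2
    k4 = func(add_fn(samples, k3, t, next_t), next_t)  # end-point slope using k3
    return add_fn(samples, (k1 + 2 * k2 + 2 * k3 + k4) / 6, t, next_t)


# One step of dx/dt = -x from t = 0 to t = 0.1, starting at x = 1.
x = torch.ones(4, 3, dtype=torch.float64)
t = torch.zeros(4, dtype=torch.float64)
next_t = torch.full((4,), 0.1, dtype=torch.float64)
x_next = rk4_step(x, t, next_t, lambda s, tt: -s)  # every entry is 0.9048375
```

For this linear ODE, one RK4 step reproduces the Taylor series of exp(-0.1) through the h^4 term, so the result agrees with the exact value to about 1e-7.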
- ml.tasks.ode.get_ode_solver(s: Literal['euler', 'heun', 'rk4']) -> BaseODESolver
Returns an ODE solver for a given key.
- Parameters:
s – The solver key to retrieve, one of 'euler', 'heun', or 'rk4'.
- Returns:
The corresponding BaseODESolver instance.
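To show how a key-based lookup and repeated step calls fit together, the sketch below builds a small solver hierarchy and a dispatch function in the spirit of get_ode_solver, then integrates dx/dt = -x over [0, 1] with each key. All class and function names here (Solver, Euler, Heun, RK4, get_solver, integrate) are hypothetical stand-ins rather than the ml.tasks.ode API, and the step formulas assume the standard textbook methods.

```python
from abc import ABC, abstractmethod
from typing import Callable, Literal

import torch

Tensor = torch.Tensor
Func = Callable[[Tensor, Tensor], Tensor]


def vanilla_add_fn(a: Tensor, b: Tensor, ta: Tensor, tb: Tensor) -> Tensor:
    # Assumed default: a + b * (tb - ta), broadcasting the (N,) times over (N, *).
    dt = (tb - ta).reshape(-1, *([1] * (b.dim() - 1)))
    return a + b * dt


class Solver(ABC):
    @abstractmethod
    def step(self, samples: Tensor, t: Tensor, next_t: Tensor, func: Func) -> Tensor: ...


class Euler(Solver):
    def step(self, samples, t, next_t, func):
        return vanilla_add_fn(samples, func(samples, t), t, next_t)


class Heun(Solver):
    def step(self, samples, t, next_t, func):
        k1 = func(samples, t)
        k2 = func(vanilla_add_fn(samples, k1, t, next_t), next_t)
        return vanilla_add_fn(samples, (k1 + k2) / 2, t, next_t)


class RK4(Solver):
    def step(self, samples, t, next_t, func):
        mid_t = (t + next_t) / 2
        k1 = func(samples, t)
        k2 = func(vanilla_add_fn(samples, k1, t, mid_t), mid_t)
        k3 = func(vanilla_add_fn(samples, k2, t, mid_t), mid_t)
        k4 = func(vanilla_add_fn(samples, k3, t, next_t), next_t)
        return vanilla_add_fn(samples, (k1 + 2 * k2 + 2 * k3 + k4) / 6, t, next_t)


def get_solver(s: Literal["euler", "heun", "rk4"]) -> Solver:
    # Hypothetical dispatch table standing in for get_ode_solver.
    solvers = {"euler": Euler, "heun": Heun, "rk4": RK4}
    if s not in solvers:
        raise ValueError(f"Unknown solver key: {s!r}")
    return solvers[s]()


def integrate(solver: Solver, x0: Tensor, times: Tensor, func: Func) -> Tensor:
    # Walk a shared 1-D time grid, expanding each scalar time to shape (N,).
    x = x0
    n = x0.shape[0]
    for t, next_t in zip(times[:-1], times[1:]):
        x = solver.step(x, t.expand(n), next_t.expand(n), func)
    return x


# Decay ODE dx/dt = -x, whose exact solution at t = 1 is x0 * exp(-1).
x0 = torch.ones(4, 2, dtype=torch.float64)
times = torch.linspace(0.0, 1.0, 11, dtype=torch.float64)
exact = x0 * torch.exp(torch.tensor(-1.0, dtype=torch.float64))
errors = {
    key: (integrate(get_solver(key), x0, times, lambda x, t: -x) - exact).abs().max().item()
    for key in ("euler", "heun", "rk4")
}
```

With 10 steps, the final error drops from roughly 2e-2 for Euler to under 1e-3 for Heun and below 1e-6 for RK4, consistent with their first-, second- and fourth-order accuracy. The higher-order solvers trade extra func evaluations per step for that accuracy.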