ml.tasks.sl.base

Defines the base supervised learning task type.

Subclasses are expected to implement the following methods:

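import ml
from torch.utils.data import Dataset

# Config, Model, Batch, Output and Loss stand for your own config, model,
# batch, output and loss types.
class MySupervisedLearningTask(ml.SupervisedLearningTask[Config, Model, Batch, Output, Loss]):
    def run_model(self, model: Model, batch: Batch, state: ml.State) -> Output:
        ...  # forward pass: produce the model output for this batch

    def compute_loss(self, model: Model, batch: Batch, state: ml.State, output: Output) -> Loss:
        ...  # compute the loss from the batch and the model output

    def get_dataset(self, phase: ml.Phase) -> Dataset:
        ...  # return the dataset for the given phase
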
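To show how the three methods fit together without depending on the `ml` package, here is a minimal, library-free sketch. `SketchSupervisedTask`, its `step` helper, and `MeanSquaredErrorTask` are hypothetical names; the assumption that a training step calls `run_model` and then passes its output to `compute_loss` is inferred from the signatures above, not taken from the library's source.

```python
from abc import ABC, abstractmethod
from typing import Generic, Literal, TypeVar

# Hypothetical stand-ins for the type parameters; not part of the ml library.
Batch = TypeVar("Batch")
Output = TypeVar("Output")
Loss = TypeVar("Loss")

Phase = Literal["train", "valid", "test"]


class SketchSupervisedTask(ABC, Generic[Batch, Output, Loss]):
    """Hypothetical, library-free mirror of the three-method contract."""

    @abstractmethod
    def run_model(self, model, batch: Batch, state) -> Output: ...

    @abstractmethod
    def compute_loss(self, model, batch: Batch, state, output: Output) -> Loss: ...

    @abstractmethod
    def get_dataset(self, phase: Phase): ...

    def step(self, model, batch: Batch, state) -> Loss:
        # Hypothetical helper: run the forward pass, then compute the loss on its output.
        output = self.run_model(model, batch, state)
        return self.compute_loss(model, batch, state, output)


class MeanSquaredErrorTask(SketchSupervisedTask[tuple, list, float]):
    def run_model(self, model, batch, state):
        xs, _ = batch
        return [model(x) for x in xs]

    def compute_loss(self, model, batch, state, output):
        _, ys = batch
        return sum((o - y) ** 2 for o, y in zip(output, ys)) / len(ys)

    def get_dataset(self, phase):
        # Same toy data for every phase; a real task would split by phase.
        return [(float(i), 2.0 * i) for i in range(4)]


task = MeanSquaredErrorTask()
loss = task.step(lambda x: 2.0 * x, ([1.0, 2.0], [2.0, 4.0]), state=None)
print(loss)  # 0.0
```

Because the base class is abstract, forgetting any one of the three methods makes instantiation fail with a `TypeError`, which catches incomplete subclasses early.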
class ml.tasks.sl.base.SupervisedLearningTaskConfig(max_epochs: int | None = None, max_steps: int | None = None, max_samples: int | None = None, max_seconds: float | None = None, train_dl: ml.tasks.base.DataLoaderConfig = <factory>, valid_dl: ml.tasks.base.DataLoaderConfig = <factory>, test_dl: ml.tasks.base.DataLoaderConfig = <factory>, name: str = '???', errors: ml.tasks.datasets.error_handling.ErrorHandlingConfig = <factory>)

Bases: BaseTaskConfig
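The four `max_*` fields look like training limits. The sketch below mirrors only those fields in a plain dataclass, assuming each one is an upper bound, `None` means unlimited, and training stops once any limit is reached. `LimitsSketch` and `is_done` are hypothetical and are not part of `ml`; the library's actual stopping logic may differ.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class LimitsSketch:
    # Mirrors the four limit fields of SupervisedLearningTaskConfig; None = no limit (assumed).
    max_epochs: int | None = None
    max_steps: int | None = None
    max_samples: int | None = None
    max_seconds: float | None = None

    def is_done(self, epochs: int, steps: int, samples: int, elapsed: float) -> bool:
        # Hypothetical rule: stop as soon as any configured limit is reached.
        checks = [
            (self.max_epochs, epochs),
            (self.max_steps, steps),
            (self.max_samples, samples),
            (self.max_seconds, elapsed),
        ]
        return any(limit is not None and value >= limit for limit, value in checks)


limits = LimitsSketch(max_steps=100)
print(limits.is_done(epochs=0, steps=99, samples=0, elapsed=0.0))   # False
print(limits.is_done(epochs=0, steps=100, samples=0, elapsed=0.0))  # True
```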

class ml.tasks.sl.base.SupervisedLearningTask(config: BaseTaskConfigT)

Bases: BaseTask[SupervisedLearningTaskConfigT, ModelT, Batch, Output, Loss], Generic[SupervisedLearningTaskConfigT, ModelT, Batch, Output, Loss], ABC

Initializes internal Module state, shared by both nn.Module and ScriptModule (constructor docstring inherited from torch.nn.Module).

get_dataset(phase: Literal['train', 'valid', 'test']) → Dataset

Returns the dataset for a given phase.

Parameters:

phase – The dataset phase to get
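A `get_dataset` implementation typically dispatches on `phase`. The sketch below is written as a standalone function so it runs without the library; in a real subclass it would be the `get_dataset` method and would presumably return a PyTorch `Dataset`. `RangeDataset` and the `SPLITS` boundaries are hypothetical, chosen only for illustration; `RangeDataset` follows the map-style dataset convention of `__len__` plus `__getitem__`.

```python
from typing import Literal

Phase = Literal["train", "valid", "test"]


class RangeDataset:
    """Minimal map-style dataset: implements __len__ and __getitem__."""

    def __init__(self, start: int, stop: int) -> None:
        self.items = list(range(start, stop))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, index: int) -> tuple[int, int]:
        x = self.items[index]
        return x, x * x  # (input, target) pair


# Hypothetical split boundaries, chosen only for illustration.
SPLITS: dict[str, tuple[int, int]] = {"train": (0, 80), "valid": (80, 90), "test": (90, 100)}


def get_dataset(phase: Phase) -> RangeDataset:
    # Reject anything outside the three documented phases.
    if phase not in SPLITS:
        raise ValueError(f"Unknown phase: {phase!r}")
    start, stop = SPLITS[phase]
    return RangeDataset(start, stop)


print(len(get_dataset("train")), len(get_dataset("valid")))  # 80 10
print(get_dataset("test")[0])  # (90, 8100)
```

Keeping the splits disjoint, as here, ensures validation and test samples never leak into training.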