ml.tasks.sl.base
Defines the base supervised learning task type.
Subclasses of this task are expected to implement the following methods:
class MySupervisedLearningTask(ml.SupervisedLearningTask[Config, Model, Batch, Output, Loss]):
    def run_model(self, model: Model, batch: Batch, state: ml.State) -> Output:
        ...

    def compute_loss(self, model: Model, batch: Batch, state: ml.State, output: Output) -> Loss:
        ...

    def get_dataset(self, phase: ml.Phase) -> Dataset:
        ...
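The three-method contract above can be illustrated without the library. The following is a minimal, self-contained sketch, assuming a training step simply calls `run_model` and then passes its output to `compute_loss`. `SketchTask`, `ScaleTask`, `train_step` and the list-based dataset are hypothetical stand-ins, not part of `ml`, and the `state: ml.State` argument is dropped for brevity.

```python
from abc import ABC, abstractmethod
from typing import Generic, List, Tuple, TypeVar

ModelT = TypeVar("ModelT")
BatchT = TypeVar("BatchT")
OutputT = TypeVar("OutputT")
LossT = TypeVar("LossT")


class SketchTask(ABC, Generic[ModelT, BatchT, OutputT, LossT]):
    """Hypothetical stand-in for the supervised learning hook contract."""

    @abstractmethod
    def run_model(self, model: ModelT, batch: BatchT) -> OutputT:
        """Forward pass: map a batch to model outputs."""

    @abstractmethod
    def compute_loss(self, model: ModelT, batch: BatchT, output: OutputT) -> LossT:
        """Score the outputs against the targets in the batch."""

    @abstractmethod
    def get_dataset(self, phase: str) -> List[BatchT]:
        """Return the batches for a phase such as "train" or "valid"."""

    def train_step(self, model: ModelT, batch: BatchT) -> LossT:
        # Hypothetical composition: forward pass, then loss on its output.
        output = self.run_model(model, batch)
        return self.compute_loss(model, batch, output)


XYBatch = Tuple[List[float], List[float]]  # (inputs, targets)


class ScaleTask(SketchTask[float, XYBatch, List[float], float]):
    """Toy task: the "model" is a single scale factor w, predictions are w * x."""

    def run_model(self, model: float, batch: XYBatch) -> List[float]:
        xs, _ = batch
        return [model * x for x in xs]

    def compute_loss(self, model: float, batch: XYBatch, output: List[float]) -> float:
        _, ys = batch
        # Mean squared error between predictions and targets.
        return sum((p - y) ** 2 for p, y in zip(output, ys)) / len(ys)

    def get_dataset(self, phase: str) -> List[XYBatch]:
        # Same tiny dataset for every phase; targets follow y = 2x.
        return [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]


task = ScaleTask()
losses = [task.train_step(2.0, b) for b in task.get_dataset("train")]
print(losses)  # [0.0, 0.0]
```

Because the base class is abstract, a subclass that forgets one of the three hooks cannot be instantiated, which mirrors the "expects you to implement" requirement.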
- class ml.tasks.sl.base.SupervisedLearningTaskConfig(max_epochs: int | None = None, max_steps: int | None = None, max_samples: int | None = None, max_seconds: float | None = None, train_dl: ml.tasks.base.DataLoaderConfig = <factory>, valid_dl: ml.tasks.base.DataLoaderConfig = <factory>, test_dl: ml.tasks.base.DataLoaderConfig = <factory>, name: str = '???', errors: ml.tasks.datasets.error_handling.ErrorHandlingConfig = <factory>)
Bases: BaseTaskConfig
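The four `max_*` fields of `SupervisedLearningTaskConfig` read as training budgets. The sketch below assumes, purely for illustration, that training stops as soon as any limit that is not `None` is reached. `LimitsSketch`, `Progress` and `is_finished` are hypothetical names, and the library's actual stopping logic may differ.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LimitsSketch:
    """Hypothetical mirror of the four max_* fields; None means "no limit"."""

    max_epochs: Optional[int] = None
    max_steps: Optional[int] = None
    max_samples: Optional[int] = None
    max_seconds: Optional[float] = None


@dataclass
class Progress:
    """Hypothetical training counters."""

    epoch: int = 0
    step: int = 0
    samples: int = 0
    elapsed_seconds: float = 0.0


def is_finished(limits: LimitsSketch, progress: Progress) -> bool:
    # Assumption: training ends once any configured limit is met or exceeded.
    checks = [
        (limits.max_epochs, progress.epoch),
        (limits.max_steps, progress.step),
        (limits.max_samples, progress.samples),
        (limits.max_seconds, progress.elapsed_seconds),
    ]
    return any(limit is not None and value >= limit for limit, value in checks)


limits = LimitsSketch(max_steps=1000, max_seconds=3600.0)
print(is_finished(limits, Progress(step=999, elapsed_seconds=10.0)))   # False
print(is_finished(limits, Progress(step=1000, elapsed_seconds=10.0)))  # True
```

With every field left at its default of `None`, this sketch never reports completion, so at least one budget would need to be set for training to end on its own.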
- class ml.tasks.sl.base.SupervisedLearningTask(config: BaseTaskConfigT)
Bases: BaseTask[SupervisedLearningTaskConfigT, ModelT, Batch, Output, Loss], Generic[SupervisedLearningTaskConfigT, ModelT, Batch, Output, Loss], ABC
Initializes internal Module state, shared by both nn.Module and ScriptModule.