# Source code for ml.tasks.sl.base

"""Defines the base supervised learning task type.

Subclasses of this task type are expected to implement the following methods:

.. code-block:: python

    class MySupervisedLearningTask(ml.SupervisedLearningTask[Config, Model, Batch, Output, Loss]):
        def run_model(self, model: Model, batch: Batch, state: ml.State) -> Output:
            ...

        def compute_loss(self, model: Model, batch: Batch, state: ml.State, output: Output) -> Loss:
            ...

        def get_dataset(self, phase: ml.Phase) -> Dataset:
            ...
"""

import logging
from abc import ABC
from dataclasses import dataclass
from typing import Generic, TypeVar

from torch.utils.data.dataset import Dataset

from ml.core.common_types import Batch, Loss, Output
from ml.core.state import Phase
from ml.models.base import BaseModel
from ml.tasks.base import BaseTask, BaseTaskConfig

# Logger scoped to this module's import name.
logger: logging.Logger = logging.getLogger(name=__name__)


@dataclass
class SupervisedLearningTaskConfig(BaseTaskConfig):
    """Configuration for supervised learning tasks.

    Declares no fields of its own; everything is inherited from
    ``BaseTaskConfig``.
    """
# Config type parameter, restricted to supervised-learning task configs.
SupervisedLearningTaskConfigT = TypeVar(
    "SupervisedLearningTaskConfigT",
    bound=SupervisedLearningTaskConfig,
)

# Model type parameter, restricted to subclasses of ``BaseModel``.
ModelT = TypeVar(
    "ModelT",
    bound=BaseModel,
)
class SupervisedLearningTask(
    BaseTask[SupervisedLearningTaskConfigT, ModelT, Batch, Output, Loss],
    Generic[SupervisedLearningTaskConfigT, ModelT, Batch, Output, Loss],
    ABC,
):
    """Base class for supervised learning tasks.

    Concrete tasks supply their data by overriding :meth:`get_dataset`;
    the implementation here only raises.
    """

    def get_dataset(self, phase: Phase) -> Dataset:
        """Returns the dataset for a given phase.

        Args:
            phase: The phase whose dataset should be returned.

        Returns:
            The dataset to use for ``phase`` (provided by the subclass).

        Raises:
            NotImplementedError: Always, unless a subclass overrides this
                method.
        """
        message = "`get_dataset` should be implemented by the task"
        raise NotImplementedError(message)