ml.tasks.base
Defines the base class and config for all tasks.
Tasks are the main unit of work in the ML framework. They define the training, validation, and testing loops, and handle data loading, logging, and model evaluation. They also time and log performance metrics, and provide hooks for running custom code at various points in the loop. Typically, you should subclass ml.tasks.sl.SupervisedLearningTask or ml.tasks.rl.ReinforcementLearningTask rather than this base class.
- class ml.tasks.base.CumulativeTimer[source]
Bases: object
Defines a simple timer to track an average value.
- property start_time: float
- property steps_per_second: float
- property steps_per_hour: float
- property seconds_per_step: float
- property hours_per_step: float
- class ml.tasks.base.IterationTimer[source]
Bases: object
Defines a simple timer to track consecutive values.
- property iter_seconds: float
- property iter_hours: float
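The timer internals are not shown in this reference, so the following is a minimal sketch of how such properties can be derived. The `step()` mutator and the private fields are assumptions for illustration, not the framework's actual API; only the property names mirror the classes above.

```python
import time


class SimpleCumulativeTimer:
    """Hypothetical reimplementation of a cumulative-average timer."""

    def __init__(self) -> None:
        self._start_time = time.time()
        self._steps = 0

    def step(self) -> None:  # assumed mutator, not documented above
        self._steps += 1

    @property
    def start_time(self) -> float:
        return self._start_time

    @property
    def steps_per_second(self) -> float:
        return self._steps / max(time.time() - self._start_time, 1e-6)

    @property
    def steps_per_hour(self) -> float:
        return self.steps_per_second * 3600.0

    @property
    def seconds_per_step(self) -> float:
        return 1.0 / max(self.steps_per_second, 1e-6)

    @property
    def hours_per_step(self) -> float:
        return self.seconds_per_step / 3600.0


class SimpleIterationTimer:
    """Hypothetical: tracks the wall time between consecutive step() calls."""

    def __init__(self) -> None:
        self._last = time.time()
        self._iter_seconds = 0.0

    def step(self) -> None:
        now = time.time()
        self._iter_seconds, self._last = now - self._last, now

    @property
    def iter_seconds(self) -> float:
        return self._iter_seconds

    @property
    def iter_hours(self) -> float:
        return self._iter_seconds / 3600.0
```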
- class ml.tasks.base.DataLoaderConfig(batch_size: int = '???', batch_size_multiplier: float = '???', shuffle: bool = '???', num_workers: int = '???', pin_memory: bool = '???', drop_last: bool = '???', timeout: float = 0, prefetch_factor: int | None = None, persistent_workers: bool = False, seed: int = 1337)[source]
Bases: object
- batch_size: int = '???'
- batch_size_multiplier: float = '???'
- shuffle: bool = '???'
- num_workers: int = '???'
- pin_memory: bool = '???'
- drop_last: bool = '???'
- timeout: float = 0
- prefetch_factor: int | None = None
- persistent_workers: bool = False
- seed: int = 1337
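The '???' defaults appear to be OmegaConf-style markers for required fields with no default; they must be filled in before the config is used. A minimal sketch of constructing one directly (the values here are arbitrary examples):

```python
from ml.tasks.base import DataLoaderConfig

# Arbitrary example values; the '???' fields have no usable default.
train_dl = DataLoaderConfig(
    batch_size=32,
    batch_size_multiplier=1.0,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)
```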
- class ml.tasks.base.DataLoaderConfigs(train_dl: ml.tasks.base.DataLoaderConfig = <factory>, valid_dl: ml.tasks.base.DataLoaderConfig = <factory>, test_dl: ml.tasks.base.DataLoaderConfig = <factory>)[source]
Bases: object
- train_dl: DataLoaderConfig
- valid_dl: DataLoaderConfig
- test_dl: DataLoaderConfig
- class ml.tasks.base.FinishTrainingConfig(max_epochs: int | None = None, max_steps: int | None = None, max_samples: int | None = None, max_seconds: float | None = None)[source]
Bases: object
- max_epochs: int | None = None
- max_steps: int | None = None
- max_samples: int | None = None
- max_seconds: float | None = None
- class ml.tasks.base.BaseTaskConfig(max_epochs: int | None = None, max_steps: int | None = None, max_samples: int | None = None, max_seconds: float | None = None, train_dl: ml.tasks.base.DataLoaderConfig = <factory>, valid_dl: ml.tasks.base.DataLoaderConfig = <factory>, test_dl: ml.tasks.base.DataLoaderConfig = <factory>, name: str = '???', errors: ml.tasks.datasets.error_handling.ErrorHandlingConfig = <factory>)[source]
Bases: BaseConfig, DataLoaderConfigs, FinishTrainingConfig
Defines the base config for all tasks.
- errors: ErrorHandlingConfig
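Since BaseTaskConfig inherits the stopping conditions from FinishTrainingConfig and the per-phase dataloader configs from DataLoaderConfigs, a single config object carries all of them. A hedged sketch of building one in code (in practice these are likely populated from a config file; values are illustrative, and valid_dl, test_dl, and errors keep their factory defaults here):

```python
from ml.tasks.base import BaseTaskConfig, DataLoaderConfig

config = BaseTaskConfig(
    name="my_task",
    max_steps=100_000,  # FinishTrainingConfig: stop after 100k steps
    train_dl=DataLoaderConfig(
        batch_size=32,
        batch_size_multiplier=1.0,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    ),
)
```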
- class ml.tasks.base.BaseTask(config: BaseTaskConfigT)[source]
Bases: Module, BaseObject[BaseTaskConfigT], Generic[BaseTaskConfigT, ModelT, Batch, Output, Loss], ABC
Defines the base task type.
- abstract run_model(model: ModelT, batch: Batch, state: State) Output [source]
Runs a single training step and returns the outputs.
- Parameters:
model – The current nn.Module
batch – The current batch
state – The current trainer state
- Returns:
The outputs from the model
- abstract compute_loss(model: ModelT, batch: Batch, state: State, output: Output) Loss [source]
Computes the loss for a given output.
If the loss is a tensor, it should have shape (B). If the loss is a dictionary of tensors, each tensor should have the same shape (B).
- Parameters:
model – The current nn.Module
batch – The current batch
state – The current trainer state
output – The model output from run_model
- Returns:
The computed loss, as a tensor or dictionary of tensors
- get_single_loss(loss: Loss) tuple[torch.Tensor, list[str]] [source]
Combines the output losses to get a single loss with shape (N, B).
- Parameters:
loss – The computed loss or losses, either a tensor or dictionary of tensors. If a dictionary, all loss tensors need to have the same shape.
- Returns:
The single loss with shape (N, B), where N is the number of losses, and the loss names, a list of length N.
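A hedged illustration of the contract rather than the internal implementation: stacking a dictionary of per-sample (B,) losses yields the combined (N, B) tensor and the N names.

```python
import torch

# Two named per-sample losses for a batch of 8.
loss = {"reconstruction": torch.rand(8), "kl": torch.rand(8)}

# Conceptually what get_single_loss returns; the actual ordering and any
# extra reduction are framework-internal details.
single = torch.stack(list(loss.values()))  # shape (2, 8)
names = list(loss.keys())                  # ["reconstruction", "kl"]
```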
- get_sampler(dataset: Dataset, cfg: DataLoaderConfig, phase: Literal['train', 'valid', 'test']) Sampler[int] [source]
Returns a dataset sampler to use in place of the default random sampling.
The default behavior for a non-iterable dataset is to use a RandomSampler over all elements of the dataset. The sampler should yield integer indices into the dataset.
- Parameters:
dataset – The dataset to sample from
cfg – The associated dataloader config
phase – The dataset’s phase
- Raises:
NotImplementedError – If this method is not overridden
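Because the base implementation raises NotImplementedError, custom sampling is opted into by overriding this method. A hedged sketch using standard PyTorch samplers; `ClassificationTask` refers to the earlier sketch, and the `sample_weights` attribute on the dataset is a hypothetical addition:

```python
from torch.utils.data import Dataset, Sampler, SequentialSampler, WeightedRandomSampler


class WeightedSamplingTask(ClassificationTask):
    def get_sampler(self, dataset: Dataset, cfg, phase) -> Sampler[int]:
        if phase == "train":
            # Hypothetical: `dataset.sample_weights` is assumed to exist on
            # this particular dataset, not on Dataset in general.
            return WeightedRandomSampler(dataset.sample_weights, num_samples=len(dataset))
        return SequentialSampler(dataset)
```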
- get_batch_sampler(sampler: Sampler, cfg: DataLoaderConfig, phase: Literal['train', 'valid', 'test']) Sampler[list[int]] [source]
Returns a dataset batch sampler to use instead of sequential sampling.
The batch sampler should yield lists of integer indices, which are the samples that are passed to the dataset.
- Parameters:
sampler – The underlying sampler
cfg – The associated dataloader config
phase – The dataset’s phase
- Raises:
NotImplementedError – If this method is not overridden
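The same override pattern applies for batch sampling. A hedged sketch wrapping the index sampler with torch's built-in BatchSampler, continuing the sketches above and reusing fields from the DataLoaderConfig:

```python
from torch.utils.data import BatchSampler, Sampler


class BatchedSamplingTask(WeightedSamplingTask):
    def get_batch_sampler(self, sampler: Sampler, cfg, phase) -> Sampler[list[int]]:
        # cfg.batch_size and cfg.drop_last come from DataLoaderConfig above.
        return BatchSampler(sampler, batch_size=cfg.batch_size, drop_last=cfg.drop_last)
```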
- apply_datapipe_transformations(datapipe: DataPipeT, phase: Literal['train', 'valid', 'test']) DataPipeT [source]
Applies transformations to the datapipe.
- Parameters:
datapipe – The datapipe to transform
phase – The dataset’s phase
- Returns:
The transformed datapipe
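A hedged sketch of an override that shuffles only the training stream; `.shuffle()` is the functional form on torchdata IterDataPipes, and the class again extends the earlier sketch:

```python
from torchdata.datapipes.iter import IterDataPipe


class ShuffledDataTask(ClassificationTask):
    def apply_datapipe_transformations(self, datapipe, phase):
        # Shuffle only during training; other phases pass through unchanged.
        if phase == "train" and isinstance(datapipe, IterDataPipe):
            datapipe = datapipe.shuffle()
        return datapipe
```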
- get_datapipe_dataloader(datapipe: MapDataPipe | IterDataPipe, phase: Literal['train', 'valid', 'test']) DataLoader [source]
- classmethod collate_fn(items: list[Any], *, mode: Literal['stack', 'concat'] = 'stack') Any | None [source]
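No docstring is given for collate_fn here, so the following usage is inferred from the signature alone: presumably 'stack' adds a new batch dimension while 'concat' joins along an existing one. Treat the expected shapes as assumptions.

```python
import torch

items = [torch.randn(3, 32) for _ in range(4)]
stacked = ClassificationTask.collate_fn(items, mode="stack")        # assumed shape (4, 3, 32)
concatenated = ClassificationTask.collate_fn(items, mode="concat")  # assumed shape (12, 32)
```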
- on_after_compute_loss(model: ModelT, batch: Batch, output: Output, loss: Loss, state: State) None [source]
- on_step_start(state: State, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
- on_step_end(state: State, loss_dict: dict[str, torch.Tensor], model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
- on_epoch_start(state: State, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
- on_epoch_end(state: State, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
- on_training_start(state: State, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
- on_training_end(state: State, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
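These hooks are the points for running custom code that the module overview mentions. A hedged sketch of overriding one, continuing the earlier task sketch; `state.num_steps` is an assumption about the State API:

```python
class LossLoggingTask(ClassificationTask):
    def on_step_end(self, state, loss_dict, model, optim, lr_sched):
        super().on_step_end(state, loss_dict, model, optim, lr_sched)
        # Hypothetical periodic logging; mean-reduce in case losses are per-sample.
        if state.num_steps % 100 == 0:
            print({name: float(value.mean()) for name, value in loss_dict.items()})
```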