ml.tasks.base

Defines the base class and config for all tasks.

Tasks are the main unit of work in the ML framework. A task defines the training, validation, and testing loops, along with data loading, logging, and model evaluation. Tasks also record timing and performance metrics, and expose hooks for running custom code at various points in the loop. In most cases you should subclass ml.tasks.sl.SupervisedLearningTask or ml.tasks.rl.ReinforcementLearningTask rather than this base class directly.

ml.tasks.base.num_workers(default: int) int[source]
class ml.tasks.base.CumulativeTimer[source]

Bases: object

Defines a simple timer for tracking average rates over all steps (steps per second, seconds per step, and so on).

property start_time: float
step(steps: int, cur_time: float) None[source]
property steps_per_second: float
property steps_per_hour: float
property seconds_per_step: float
property hours_per_step: float
class ml.tasks.base.IterationTimer[source]

Bases: object

Defines a simple timer for tracking the time between consecutive steps.

step(cur_time: float) None[source]
property iter_seconds: float
property iter_hours: float
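
A minimal usage sketch of the two timers above, assuming each is stepped once per iteration with wall-clock times (the loop body is illustrative):

    import time

    from ml.tasks.base import CumulativeTimer, IterationTimer

    cumulative_timer = CumulativeTimer()
    iteration_timer = IterationTimer()

    for step in range(100):
        cur_time = time.time()
        cumulative_timer.step(step, cur_time)  # Running average over all steps.
        iteration_timer.step(cur_time)  # Time between consecutive steps.

    print(cumulative_timer.steps_per_second)  # Average throughput.
    print(iteration_timer.iter_seconds)  # Duration of the latest iteration.
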
class ml.tasks.base.StateTimer[source]

Bases: object

Defines a timer that tracks timing statistics derived from the trainer state.

step(state: State) None[source]
log_dict() dict[str, dict[str, int | float]][source]
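
A hedged sketch of driving StateTimer from a training loop; state is assumed to be the trainer's current State object:

    from ml.tasks.base import StateTimer

    state_timer = StateTimer()
    state_timer.step(state)  # Update from the current trainer State.
    metrics = state_timer.log_dict()  # Nested dict of int/float timing values.
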
class ml.tasks.base.DataLoaderConfig(batch_size: int = '???', batch_size_multiplier: float = '???', shuffle: bool = '???', num_workers: int = '???', pin_memory: bool = '???', drop_last: bool = '???', timeout: float = 0, prefetch_factor: int | None = None, persistent_workers: bool = False, seed: int = 1337)[source]

Bases: object

batch_size: int = '???'
batch_size_multiplier: float = '???'
shuffle: bool = '???'
num_workers: int = '???'
pin_memory: bool = '???'
drop_last: bool = '???'
timeout: float = 0
prefetch_factor: int | None = None
persistent_workers: bool = False
seed: int = 1337
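
The '???' defaults mark required fields (OmegaConf-style missing values, judging by the notation); a sketch of filling them in directly, with illustrative values:

    from ml.tasks.base import DataLoaderConfig

    train_dl = DataLoaderConfig(
        batch_size=32,
        batch_size_multiplier=1.0,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
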
class ml.tasks.base.DataLoaderConfigs(train_dl: ml.tasks.base.DataLoaderConfig = <factory>, valid_dl: ml.tasks.base.DataLoaderConfig = <factory>, test_dl: ml.tasks.base.DataLoaderConfig = <factory>)[source]

Bases: object

train_dl: DataLoaderConfig
valid_dl: DataLoaderConfig
test_dl: DataLoaderConfig
class ml.tasks.base.FinishTrainingConfig(max_epochs: int | None = None, max_steps: int | None = None, max_samples: int | None = None, max_seconds: float | None = None)[source]

Bases: object

max_epochs: int | None = None
max_steps: int | None = None
max_samples: int | None = None
max_seconds: float | None = None
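
Each field is an independent stopping criterion; presumably training finishes when any configured limit is reached. For example:

    from ml.tasks.base import FinishTrainingConfig

    # Stop after 100 epochs or 24 hours, whichever comes first (assumed semantics).
    finish_cfg = FinishTrainingConfig(max_epochs=100, max_seconds=24 * 60 * 60)
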
class ml.tasks.base.BaseTaskConfig(max_epochs: int | None = None, max_steps: int | None = None, max_samples: int | None = None, max_seconds: float | None = None, train_dl: ~ml.tasks.base.DataLoaderConfig = <factory>, valid_dl: ~ml.tasks.base.DataLoaderConfig = <factory>, test_dl: ~ml.tasks.base.DataLoaderConfig = <factory>, name: str = '???', errors: ~ml.tasks.datasets.error_handling.ErrorHandlingConfig = <factory>)[source]

Bases: BaseConfig, DataLoaderConfigs, FinishTrainingConfig

Defines the base config for all tasks.

errors: ErrorHandlingConfig
class ml.tasks.base.BaseTask(config: BaseTaskConfigT)[source]

Bases: Module, BaseObject[BaseTaskConfigT], Generic[BaseTaskConfigT, ModelT, Batch, Output, Loss], ABC

Defines the base task type.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

abstract run_model(model: ModelT, batch: Batch, state: State) Output[source]

Runs a single training step and returns the outputs.

Parameters:
  • model – The current nn.Module

  • batch – The current batch

  • state – The current trainer state

Returns:

The outputs from the model

abstract compute_loss(model: ModelT, batch: Batch, state: State, output: Output) Loss[source]

Computes the loss for a given output.

If the loss is a tensor, it should have shape (B). If the loss is a dictionary of tensors, each tensor should have the same shape (B).

Parameters:
  • model – The current nn.Module

  • batch – The current batch

  • state – The current trainer state

  • output – The model output from run_model

Returns:

The computed loss, as a tensor or dictionary of tensors
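
A minimal sketch of a concrete task implementing the two abstract methods above, assuming a supervised (inputs, labels) batch; the class name and batch structure are illustrative, and the config plumbing is omitted:

    import torch.nn.functional as F

    from ml.tasks.base import BaseTask

    class MyClassificationTask(BaseTask):
        def run_model(self, model, batch, state):
            inputs, _ = batch
            return model(inputs)

        def compute_loss(self, model, batch, state, output):
            _, labels = batch
            # Per-sample loss with shape (B), as required above.
            return F.cross_entropy(output, labels, reduction="none")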

get_single_loss(loss: Loss) tuple[torch.Tensor, list[str]][source]

Combines the output losses to get a single loss with shape (N, B).

Parameters:

loss – The computed loss or losses, either a tensor or dictionary of tensors. If a dictionary, all loss tensors need to have the same shape.

Returns:

The single loss with shape (N, B), where N is the number of losses and B is the batch size, and the loss names, a list of length N.
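
As a sketch of the combination this performs, a dictionary of per-sample losses can be stacked into a single tensor matching the documented shapes (this mirrors the contract, not necessarily the exact implementation):

    import torch

    losses = {"recon": torch.rand(8), "kl": torch.rand(8)}  # N = 2 losses, batch size B = 8.
    single_loss = torch.stack(list(losses.values()))  # Shape (2, 8), i.e. (N, B).
    loss_names = list(losses.keys())  # Length N.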

log_loss_dict(loss: Mapping[str, int | float | Tensor], state: State) None[source]
get_batch_size(batch: Batch) int | None[source]
set_training_over() None[source]
maybe_log_termination_time(remaining_percent: float, state: State) None[source]
get_remaining_percent(state: State) float | None[source]
is_training_over(state: State) bool[source]
get_sampler(dataset: Dataset, cfg: DataLoaderConfig, phase: Literal['train', 'valid', 'test']) Sampler[int][source]

Returns a dataset sampler to use instead of random sampling.

The default behavior for a non-iterable dataset is to use a RandomSampler over all elements of the dataset. The sampler should yield integer indices into the dataset.

Parameters:
  • dataset – The dataset to sample from

  • cfg – The associated dataloader config

  • phase – The dataset’s phase

Raises:

NotImplementedError – If this method is not overridden
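
The base method raises NotImplementedError, which presumably signals the dataloader construction to fall back on the default random sampling; to customize sampling, override it. A hedged sketch using a standard PyTorch sampler:

    import torch
    from torch.utils.data import Dataset, Sampler, WeightedRandomSampler

    from ml.tasks.base import BaseTask, DataLoaderConfig

    class MyWeightedTask(BaseTask):
        def get_sampler(self, dataset: Dataset, cfg: DataLoaderConfig, phase: str) -> Sampler[int]:
            # Hypothetical: oversample according to per-sample weights.
            weights = torch.ones(len(dataset))  # Replace with real per-sample weights.
            generator = torch.Generator().manual_seed(cfg.seed)
            return WeightedRandomSampler(weights, num_samples=len(dataset), generator=generator)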

get_batch_sampler(sampler: Sampler, cfg: DataLoaderConfig, phase: Literal['train', 'valid', 'test']) Sampler[list[int]][source]

Returns a dataset batch sampler to use instead of sequential sampling.

The batch sampler should yield lists of integer indices, which are the samples that are passed to the dataset.

Parameters:
  • sampler – The underlying sampler

  • cfg – The associated dataloader config

  • phase – The dataset’s phase

Raises:

NotImplementedError – If this method is not overridden
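
Similarly, a custom batch sampler can wrap the underlying sampler; PyTorch's BatchSampler gives the simplest sketch:

    from torch.utils.data import BatchSampler, Sampler

    from ml.tasks.base import BaseTask, DataLoaderConfig

    class MyBatchedTask(BaseTask):
        def get_batch_sampler(self, sampler: Sampler, cfg: DataLoaderConfig, phase: str) -> Sampler[list[int]]:
            # Groups the sampler's indices into lists of cfg.batch_size indices.
            return BatchSampler(sampler, batch_size=cfg.batch_size, drop_last=cfg.drop_last)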

apply_datapipe_transformations(datapipe: DataPipeT, phase: Literal['train', 'valid', 'test']) DataPipeT[source]

Applies transformations to the datapipe.

Parameters:
  • datapipe – The datapipe to transform

  • phase – The dataset’s phase

Returns:

The transformed datapipe
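
A hedged example of overriding this hook to shuffle only the training datapipe; shuffle() is the standard IterDataPipe functional form, and whether it belongs here depends on your pipeline:

    from ml.tasks.base import BaseTask

    class MyPipeTask(BaseTask):
        def apply_datapipe_transformations(self, datapipe, phase):
            if phase == "train":
                # Shuffle during training; keep evaluation order deterministic.
                datapipe = datapipe.shuffle()
            return datapipe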

get_datapipe_dataloader(datapipe: MapDataPipe | IterDataPipe, phase: Literal['train', 'valid', 'test']) DataLoader[source]
get_dataloader(dataset: Dataset, phase: Literal['train', 'valid', 'test']) DataLoader[source]
classmethod worker_init_fn(worker_id: int) None[source]
classmethod collate_fn(items: list[Any], *, mode: Literal['stack', 'concat'] = 'stack') Any | None[source]
on_after_save_checkpoint(ckpt_path: Path) None[source]
on_before_forward_step(model: ModelT, batch: Batch, state: State) None[source]
on_after_forward_step(model: ModelT, batch: Batch, output: Output, state: State) None[source]
on_after_compute_loss(model: ModelT, batch: Batch, output: Output, loss: Loss, state: State) None[source]
on_step_start(state: State, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
on_step_end(state: State, loss_dict: dict[str, torch.Tensor], model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
on_epoch_start(state: State, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
on_epoch_end(state: State, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
on_training_start(state: State, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
on_training_end(state: State, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
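
These on_* methods are lifecycle hooks; they are presumably no-ops by default, so a subclass overrides only the ones it needs. An illustrative example:

    from ml.tasks.base import BaseTask

    class MyLoggingTask(BaseTask):
        def on_step_end(self, state, loss_dict, model, optim, lr_sched):
            # Illustrative: print each loss term at the end of every step.
            for name, value in loss_dict.items():
                print(f"{name}: {value.item():.4f}")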