ml.core.common_types

Defines common types used in the ML package.

This package makes heavy use of static typing to help with code readability and maintainability. For example, when implementing a new task, you should specify the five generic type parameters of the base task (config, model, batch, output, and loss), as in the example below:

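import torch.nn.functional as F
from torch import Tensor

# BaseTask, State, SomeTaskConfig, and SomeModel are assumed to be imported or defined elsewhere.
Batch = Tensor
Output = Tensor
Loss = Tensor

class SomeTask(BaseTask[SomeTaskConfig, SomeModel, Batch, Output, Loss]):
    def run_model(self, model: SomeModel, batch: Batch, state: State) -> Output:
        # Forward pass; the return value must match the Output type.
        return model(batch)

    def compute_loss(self, model: SomeModel, batch: Batch, state: State, output: Output) -> Loss:
        # Reconstruction-style loss: the batch itself is the target.
        return F.mse_loss(output, batch)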

These type parameters give the task’s methods precise signatures, so that mypy or any other static type checker can verify that the types are correct.
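To make the mechanism concrete, here is a minimal, dependency-free sketch of how a base class that is generic over five type parameters might look. It is not the package's actual BaseTask: ToyBaseTask, ScaleConfig, ScaleModel, ScaleTask, and the step method are hypothetical names invented for illustration, the State argument is omitted for brevity, and plain Python lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

# Hypothetical type variables, one per generic slot of the toy base task.
ConfigT = TypeVar("ConfigT")
ModelT = TypeVar("ModelT")
BatchT = TypeVar("BatchT")
OutputT = TypeVar("OutputT")
LossT = TypeVar("LossT")


class ToyBaseTask(Generic[ConfigT, ModelT, BatchT, OutputT, LossT]):
    """Illustrative stand-in for a generic base task (an assumption, not the real class)."""

    def __init__(self, config: ConfigT) -> None:
        self.config = config

    def run_model(self, model: ModelT, batch: BatchT) -> OutputT:
        raise NotImplementedError

    def compute_loss(self, model: ModelT, batch: BatchT, output: OutputT) -> LossT:
        raise NotImplementedError

    def step(self, model: ModelT, batch: BatchT) -> LossT:
        # The checker knows `output` is OutputT and the result is LossT.
        output = self.run_model(model, batch)
        return self.compute_loss(model, batch, output)


@dataclass
class ScaleConfig:
    loss_scale: float = 1.0


class ScaleModel:
    def __init__(self, weight: float) -> None:
        self.weight = weight

    def __call__(self, xs: list[float]) -> list[float]:
        return [self.weight * x for x in xs]


# Concrete types bound to the batch, output, and loss slots.
Batch = list[float]
Output = list[float]
Loss = float


class ScaleTask(ToyBaseTask[ScaleConfig, ScaleModel, Batch, Output, Loss]):
    def run_model(self, model: ScaleModel, batch: Batch) -> Output:
        return model(batch)

    def compute_loss(self, model: ScaleModel, batch: Batch, output: Output) -> Loss:
        # Mean squared error between the output and the batch, scaled by the config.
        mse = sum((o - b) ** 2 for o, b in zip(output, batch)) / len(batch)
        return self.config.loss_scale * mse


task = ScaleTask(ScaleConfig(loss_scale=1.0))
loss = task.step(ScaleModel(weight=2.0), [1.0, 2.0, 3.0])
```

Because ScaleTask binds all five slots, a type checker can infer that task.step returns a float and that self.config is a ScaleConfig. If run_model were annotated to return str instead of Output, a checker such as mypy would be expected to report an incompatible override.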