ml.models.base

The base object and config for all models.

This is essentially just a small wrapper around a vanilla PyTorch module.

class ml.models.base.BaseModelConfig(name: str = '???')

Bases: BaseConfig

Defines the base config for all modules.

ml.models.base.summarize(names: list[tuple[str, torch.device]]) → str

class ml.models.base.BaseModel(config: ModelConfigT)

Bases: BaseObject[ModelConfigT], Generic[ModelConfigT], Module

Defines the base module type.

init(device: device, dtype: dtype | None = None) → None
get_device() → device
get_dtype() → dtype
tensor_to(tensor: Tensor, non_blocking: bool = False) → Tensor
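
The signatures above suggest that a model can report its own device and dtype and move an input tensor to match. The following is a minimal sketch of that pattern in plain PyTorch. It is not the BaseModel implementation: the class name SketchModel, the choice to read the device and dtype from the first parameter, and the rule of casting only floating-point tensors are all assumptions made for illustration.

```python
import torch
from torch import nn


class SketchModel(nn.Module):
    """Hypothetical stand-in; does not use or reproduce ml.models.base."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def get_device(self) -> torch.device:
        # Assumption: the model's device is the device of its first parameter.
        return next(self.parameters()).device

    def get_dtype(self) -> torch.dtype:
        # Assumption: the model's dtype is the dtype of its first parameter.
        return next(self.parameters()).dtype

    def tensor_to(self, tensor: torch.Tensor, non_blocking: bool = False) -> torch.Tensor:
        device = self.get_device()
        if tensor.is_floating_point():
            # Floating-point inputs are moved and cast to the model's dtype.
            return tensor.to(device=device, dtype=self.get_dtype(), non_blocking=non_blocking)
        # Integer and boolean inputs (e.g. token ids) are moved but keep their dtype.
        return tensor.to(device=device, non_blocking=non_blocking)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.tensor_to(x))


model = SketchModel().to(torch.float64)
x = torch.randn(3, 4)  # float32 input, differs from the model's float64 weights
y = model(x)           # tensor_to casts x before the linear layer
ids = model.tensor_to(torch.tensor([1, 2, 3]))  # integer tensor keeps its dtype
```

Routing inputs through a single helper like this keeps device and dtype handling in one place, so calling code does not need to know where the weights live.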