ml.utils.amp

Helper functions for mixed-precision training.

class ml.utils.amp.autocast_all(device_types: list[str] | None = None, enabled: bool = True, cache_enabled: bool | None = None)

Bases: object

ml.utils.amp.default_device() → str
ml.utils.amp.default_dtype(enabled: bool) → dtype
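These two helpers are undocumented above. A plausible behavior, sketched here as an assumption rather than the library's actual implementation, is that `default_device` prefers an accelerator when one is present and `default_dtype` maps the autocast flag to a per-device precision. Dtypes and devices are represented as strings so the example runs without torch installed.

```python
def default_device_sketch(cuda_available: bool = False) -> str:
    # Assumed logic: prefer an accelerator when present, else CPU.
    return "cuda" if cuda_available else "cpu"

def default_dtype_sketch(enabled: bool, device: str = "cuda") -> str:
    # When autocast is disabled, stay in full precision.
    if not enabled:
        return "float32"
    # When enabled, pick the low-precision type the device handles well:
    # BF16 on CPU, FP16 on CUDA (a common convention, assumed here).
    return "bfloat16" if device == "cpu" else "float16"
```

This matches the docstring below, which says enabling autocast converts tensors to the device's default floating-point type (typically FP16 or BF16) and disabling converts them to FP32.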
class ml.utils.amp.autocast_tensors(xs: T | None = None, device_type: str | None = None, dtype: dtype | None = None, enabled: bool = True, cache_enabled: bool | None = None)

Bases: Generic[T]

Defines a context manager for enabling or disabling autocasting.

In addition to toggling autocast, this context manager converts the given tensor or container of tensors to the dtype the device expects. When autocast is enabled, tensors are converted to the device's default low-precision floating-point type (typically FP16 or BF16); when autocast is disabled, they are converted back to FP32.

Parameters:
  • xs – The tensor or container of tensors to autocast.

  • device_type – The device type to use for autocasting. If not specified, the default device type will be used.

  • dtype – The dtype to use for autocasting. If not specified, the default dtype will be used.

  • enabled – Whether to enable or disable autocasting.

  • cache_enabled – Whether to enable or disable the cache for autocasting. If not specified, the default cache setting will be used.

apply(xs: Any) → Any
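The docstring above speaks of converting "a tensor or container of tensors". The recursive traversal that implies can be sketched as follows; this is a pure-Python illustration with a stand-in tensor class, not the library's actual `apply` implementation.

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class FakeTensor:
    # Stand-in for torch.Tensor carrying only a dtype tag.
    dtype: str

    def to(self, dtype: str) -> "FakeTensor":
        # Like torch.Tensor.to, returns a new tensor; the original
        # is left unchanged.
        return FakeTensor(dtype)

def apply_sketch(xs: Any, dtype: str) -> Any:
    """Recursively cast every tensor found in a nested container."""
    if isinstance(xs, FakeTensor):
        return xs.to(dtype)
    if isinstance(xs, dict):
        return {k: apply_sketch(v, dtype) for k, v in xs.items()}
    if isinstance(xs, (list, tuple)):
        return type(xs)(apply_sketch(v, dtype) for v in xs)
    # Non-tensor leaves (ints, strings, None, ...) pass through unchanged.
    return xs
```

For example, applying this to a batch dict converts every tensor leaf while leaving labels and metadata untouched, which is the behavior the `xs` parameter description suggests.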