ml.trainers.mixins.compile

A trainer mixin to support torch.compile.

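By default this is disabled. It can be enabled through environment variables or through your configuration file:

- Set COMPILE_MODEL=1 to compile the model, and/or COMPILE_FUNC=1 to compile functions.
- Equivalently, set trainer.compiler.model=true and/or trainer.compiler.func=true in your configuration file.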

class ml.trainers.mixins.compile.TorchCompileConfig(model: bool = '${oc.decode:${oc.env:COMPILE_MODEL,0}}', func: bool = '${oc.decode:${oc.env:COMPILE_FUNC,0}}', fullgraph: bool = False, dynamic: bool = False, backend: str = 'auto', model_mode: str | None = 'max-autotune', func_mode: str | None = 'reduce-overhead')[source]

Bases: object

model: bool = '${oc.decode:${oc.env:COMPILE_MODEL,0}}'
func: bool = '${oc.decode:${oc.env:COMPILE_FUNC,0}}'
fullgraph: bool = False
dynamic: bool = False
backend: str = 'auto'
model_mode: str | None = 'max-autotune'
func_mode: str | None = 'reduce-overhead'
class ml.trainers.mixins.compile.CompileConfig(name: str = '???', exp_name: str = '${ml.exp_name:null}', exp_dir: str = '???', log_dir_name: str = 'logs', use_double_weight_precision: bool = False, checkpoint: ml.trainers.base.CheckpointConfig = <factory>, compiler: ml.trainers.mixins.compile.TorchCompileConfig = <factory>)[source]

Bases: BaseTrainerConfig

compiler: TorchCompileConfig
class ml.trainers.mixins.compile.CompileMixin(config: TrainerConfigT)[source]

Bases: BaseTrainer[CompileConfigT, ModelT, TaskT]

Defines a mixin for calling torch.compile on models.