ml.trainers.mixins.compile
A trainer mixin to support torch.compile.
By default compilation is disabled. It can be enabled by setting the COMPILE_MODEL and/or COMPILE_FUNC environment variables to 1, or by enabling the corresponding flags in the compiler section of your trainer configuration.
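These defaults are wired through OmegaConf's oc.env and oc.decode resolvers, as seen in the TorchCompileConfig signature below. A minimal standalone sketch of how the interpolation resolves (illustrative only; the bare dict config here stands in for the real structured config):

```python
import os
from omegaconf import OmegaConf

# Setting the variable before the config is resolved turns on model compilation.
os.environ["COMPILE_MODEL"] = "1"

cfg = OmegaConf.create(
    {
        "model": "${oc.decode:${oc.env:COMPILE_MODEL,0}}",
        "func": "${oc.decode:${oc.env:COMPILE_FUNC,0}}",
    }
)
# oc.env falls back to the default ("0") when the variable is unset;
# oc.decode parses the string into a Python value.
print(OmegaConf.to_container(cfg, resolve=True))  # -> {'model': 1, 'func': 0}
```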
- class ml.trainers.mixins.compile.TorchCompileConfig(model: bool = '${oc.decode:${oc.env:COMPILE_MODEL,0}}', func: bool = '${oc.decode:${oc.env:COMPILE_FUNC,0}}', fullgraph: bool = False, dynamic: bool = False, backend: str = 'auto', model_mode: str | None = 'max-autotune', func_mode: str | None = 'reduce-overhead')[source]
Bases: object
- model: bool = '${oc.decode:${oc.env:COMPILE_MODEL,0}}'
- func: bool = '${oc.decode:${oc.env:COMPILE_FUNC,0}}'
- fullgraph: bool = False
- dynamic: bool = False
- backend: str = 'auto'
- model_mode: str | None = 'max-autotune'
- func_mode: str | None = 'reduce-overhead'
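The fullgraph, dynamic, backend, and mode fields mirror the keyword arguments of torch.compile. A hedged sketch of how a model could be compiled from such a config (the build_compiled_model helper and the fallback from the "auto" backend to "inductor" are assumptions for illustration, not this module's actual logic):

```python
import torch
from torch import nn

from ml.trainers.mixins.compile import TorchCompileConfig


def build_compiled_model(model: nn.Module, cfg: TorchCompileConfig) -> nn.Module:
    """Illustrative only: forwards the config fields to ``torch.compile``."""
    if not cfg.model:
        return model  # compilation disabled; keep the eager model
    # "auto" is not a torch.compile backend name; assume it means "use the default".
    backend = "inductor" if cfg.backend == "auto" else cfg.backend
    return torch.compile(
        model,
        fullgraph=cfg.fullgraph,
        dynamic=cfg.dynamic,
        backend=backend,
        mode=cfg.model_mode,
    )
```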
- class ml.trainers.mixins.compile.CompileConfig(name: str = '???', exp_name: str = '${ml.exp_name:null}', exp_dir: str = '???', log_dir_name: str = 'logs', use_double_weight_precision: bool = False, checkpoint: ml.trainers.base.CheckpointConfig = <factory>, compiler: ml.trainers.mixins.compile.TorchCompileConfig = <factory>)[source]
Bases: BaseTrainerConfig
- compiler: TorchCompileConfig
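CompileConfig extends the base trainer config with a compiler section, so the options above are overridden under the compiler key. A small sketch using OmegaConf dot-list overrides (the keys follow the field names above; the merge itself is illustrative):

```python
from omegaconf import OmegaConf

from ml.trainers.mixins.compile import CompileConfig

base = OmegaConf.structured(CompileConfig)
overrides = OmegaConf.from_dotlist(
    ["compiler.model=true", "compiler.model_mode=default", "compiler.fullgraph=true"]
)
cfg = OmegaConf.merge(base, overrides)  # interpolations stay lazy until accessed
```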
- class ml.trainers.mixins.compile.CompileMixin(config: TrainerConfigT)[source]
Bases: BaseTrainer[CompileConfigT, ModelT, TaskT]
Defines a mixin for calling torch.compile on models.
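Besides the model, the config also allows compiling a standalone function with its own mode, which is what the func and func_mode fields control. The exact function the mixin compiles is not documented above; the sketch below only illustrates what such a compilation would look like, using an assumed loss function:

```python
import torch
import torch.nn.functional as F
from torch import nn

model = nn.Linear(8, 4)


def loss_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return F.mse_loss(model(x), y)


# func=True with func_mode="reduce-overhead" would correspond to something like:
compiled_loss_fn = torch.compile(loss_fn, mode="reduce-overhead")

x, y = torch.randn(16, 8), torch.randn(16, 4)
loss = compiled_loss_fn(x, y)
loss.backward()
```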