"""Defines a launcher for multiprocess training.
This can be used with distributed data parallel (DDP) or fully sharded data
parallel (FSDP) training. The launcher will spawn a process for each device
and initialize the process group for DDP or FSDP training.
This launcher expects to run on a single machine with one or more GPUs.
"""
import functools
import logging
from dataclasses import dataclass

import torch
from omegaconf import DictConfig

from ml.core.config import conf_field
from ml.core.registry import register_launcher
from ml.launchers.base import BaseLauncher, BaseLauncherConfig
from ml.scripts.train import train_main
from ml.utils.torch_distributed import MultiprocessConfig, launch_subprocesses

logger: logging.Logger = logging.getLogger(__name__)


def process_main(cfg: MultiprocessConfig, raw_config: DictConfig) -> None:
    """Entry point for each spawned training process; runs the training script."""

    train_main(raw_config)


@dataclass
class MultiProcessLauncherConfig(BaseLauncherConfig):
    multiprocess: MultiprocessConfig = conf_field(MultiprocessConfig())

    @classmethod
    def resolve(cls: type["MultiProcessLauncherConfig"], config: "MultiProcessLauncherConfig") -> None:
        super().resolve(config)

        # Resolve multiprocess config.
        MultiprocessConfig.resolve(config.multiprocess)
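

# Constructing and resolving this config by hand (an illustrative sketch; in
# normal use the framework builds it from the training config, and what
# `MultiprocessConfig.resolve` actually fills in is defined in
# ml.utils.torch_distributed rather than here):
#
#     config = MultiProcessLauncherConfig()
#     MultiProcessLauncherConfig.resolve(config)
#     print(config.multiprocess)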


@register_launcher("mp", MultiProcessLauncherConfig)
class MultiProcessLauncher(BaseLauncher[MultiProcessLauncherConfig]):
    def launch(self) -> None:
        if not torch.cuda.is_available():
            logger.warning("MultiProcessLauncher expects CUDA to be available")

        # Each spawned subprocess runs the training script with the same raw config.
        func = functools.partial(
            process_main,
            cfg=self.config.multiprocess,
            raw_config=self.raw_config,
        )
        launch_subprocesses(func, self.config.multiprocess)
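

# Illustrative sketch (not part of the launcher): the class above delegates the
# actual process management to `launch_subprocesses`. The helpers below
# approximate what that call is relied on to do for this module: spawn one
# worker per CUDA device and have each worker join a process group before
# running the training function. The real implementation lives in
# ml.utils.torch_distributed and may differ in signature and behavior.

import os
from typing import Callable

import torch.distributed as dist
import torch.multiprocessing as mp


def _sketch_worker(rank: int, world_size: int, train_fn: Callable[[], None]) -> None:
    """Per-process body: joins the process group, runs training, then cleans up."""
    # Rendezvous settings; a real launcher typically makes these configurable.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    try:
        train_fn()
    finally:
        dist.destroy_process_group()


def _launch_subprocesses_sketch(train_fn: Callable[[], None]) -> None:
    """Spawns one worker per GPU (or a single CPU worker if CUDA is unavailable)."""
    world_size = max(torch.cuda.device_count(), 1)
    # `mp.spawn` starts `world_size` workers, passing each its rank as the first
    # positional argument.
    mp.spawn(_sketch_worker, args=(world_size, train_fn), nprocs=world_size, join=True)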