Source code for ml.launchers.torchrun

"""Defines a launcher which uses `torchrun` to launch a job.

This is a lightweight wrapper around PyTorch's `torchrun` CLI (the console
entry point for `torch.distributed.run`). It is used to launch a job on a
single node with multiple processes; by default, one process is started per
visible CUDA device.
"""

import logging
import shlex
import shutil
import subprocess
from dataclasses import dataclass

import torch
from omegaconf import MISSING, OmegaConf

from ml.core.config import conf_field
from ml.core.registry import project_dirs, register_launcher, register_trainer
from ml.launchers.base import BaseLauncher, BaseLauncherConfig
from ml.utils.networking import get_unused_port

logger: logging.Logger = logging.getLogger(__name__)

# Starting port handed to `get_unused_port` when `master_port` is not set in
# the config (see `TorchRunLauncherConfig.resolve`).
# NOTE(review): presumably the free-port search starts here — confirm in
# ml.utils.networking.
DEFAULT_PORT = 29500

# Source of the entry-point script that `TorchRunLauncher.launch` writes to
# `<exp_dir>/torchrun.py` and passes to torchrun via `--run-path`.
# Placeholders are filled with `str.format`:
#   {project_root} - a Python list literal of project directory strings,
#                    re-registered with the registry before training starts.
#   {config_path}  - the saved trainer config, loaded with `OmegaConf.load`
#                    and passed to `train_main`.
# The caller `.strip()`s the rendered text so the shebang lands on line one.
# NOTE: `{config_path}` is substituted between single quotes, so the value
# must already be escaped for a single-quoted Python literal (backslashes,
# quotes). Any literal `{` / `}` added to this template must be doubled.
TORCHRUN_TEMPLATE: str = """
#!/usr/bin/env python

from pathlib import Path

from omegaconf import OmegaConf

from ml.core.registry import project_dirs as registry_project_dirs
from ml.scripts.train import train_main

PROJECT_DIRS = {project_root}
CONFIG_PATH = '{config_path}'


def main() -> None:
    for p in PROJECT_DIRS:
        registry_project_dirs.add(Path(p))
    config = OmegaConf.load(CONFIG_PATH)
    train_main(config)


if __name__ == "__main__":
    main()
"""


@dataclass
class TorchRunLauncherConfig(BaseLauncherConfig):
    """Configuration for the `torchrun` launcher.

    Fields left as ``MISSING`` are filled in at runtime by :meth:`resolve`.
    """

    # Passed to torchrun as --nproc-per-node; resolved from the CUDA device
    # count when missing.
    nproc_per_node: int = conf_field(MISSING, help="The number of processes per node")
    # Passed to torchrun as --master-addr.
    master_addr: str = conf_field("127.0.0.1", help="The address of the master")
    # Passed to torchrun as --master-port; chosen via get_unused_port when missing.
    master_port: int = conf_field(MISSING, help="The port of the master")
    # Not part of the torchrun command built in this module.
    # NOTE(review): presumably consumed by the training code — confirm.
    backend: str = conf_field("nccl", help="The backend to use for distributed training")
    # Passed to torchrun as --start-method.
    start_method: str = conf_field("spawn", help="The method to use to start processes")
    # Path to the `torchrun` executable; looked up on PATH when missing.
    torchrun_path: str = conf_field(MISSING, help="The path to the TorchRun script")

    @classmethod
    def resolve(cls: type["TorchRunLauncherConfig"], config: "TorchRunLauncherConfig") -> None:
        """Fills in ``MISSING`` fields of ``config`` in place.

        Args:
            config: The launcher config to resolve.

        Raises:
            ValueError: If ``torchrun_path`` is missing and no ``torchrun``
                executable can be found on ``PATH``.
        """
        super().resolve(config)

        if OmegaConf.is_missing(config, "nproc_per_node"):
            num_devices = torch.cuda.device_count()
            if num_devices == 0:
                # Without visible CUDA devices the count would resolve to 0,
                # and torchrun cannot start zero workers on a node. Fall back
                # to a single process instead.
                logger.warning("No CUDA devices found; defaulting nproc_per_node to 1")
                num_devices = 1
            config.nproc_per_node = num_devices

        if OmegaConf.is_missing(config, "master_port"):
            config.master_port = get_unused_port(DEFAULT_PORT)

        if OmegaConf.is_missing(config, "torchrun_path"):
            torchrun_path = shutil.which("torchrun")
            if torchrun_path is None:
                raise ValueError("Could not find torchrun in PATH")
            config.torchrun_path = torchrun_path
@register_launcher("torchrun", TorchRunLauncherConfig)
class TorchRunLauncher(BaseLauncher[TorchRunLauncherConfig]):
    """Launches a single-node job through the ``torchrun`` CLI.

    ``launch`` renders ``TORCHRUN_TEMPLATE`` into ``<exp_dir>/torchrun.py`` and
    runs ``torchrun`` on it with ``--run-path``.
    """

    @staticmethod
    def _escape_for_single_quoted_literal(value: str) -> str:
        """Escapes ``value`` so it can sit between single quotes in Python source.

        The ``unicode_escape`` codec turns backslashes, control characters and
        non-ASCII characters into escape sequences; it leaves single quotes
        alone, so those are escaped separately.
        """
        return value.encode("unicode_escape").decode("ascii").replace("'", "\\'")

    def launch(self) -> None:
        """Launches the job by calling the TorchRun CLI in a subprocess.

        Builds the trainer from the raw config and saves its config. It then
        writes the entry-point script and blocks until ``torchrun`` exits.

        Raises:
            subprocess.CalledProcessError: If ``torchrun`` exits with a
                non-zero status.
        """
        trainer = register_trainer.build_entry_non_null(self.raw_config)
        trainer.save_config()

        # Builds the run file. The template places the config path between
        # single quotes, so it is escaped here. Otherwise a Windows path like
        # C:\runs\new would be read back with \r / \n escapes, and a quote in
        # the path would make the generated script fail to parse.
        torchrun_file = TORCHRUN_TEMPLATE.format(
            project_root=repr([str(p) for p in project_dirs.paths]),
            config_path=self._escape_for_single_quoted_literal(str(trainer.config_path)),
        ).strip()
        torchrun_fpath = trainer.exp_dir / "torchrun.py"
        with open(torchrun_fpath, "w", encoding="utf-8") as f:
            f.write(torchrun_file)
        logger.info("Wrote torchrun file to %s", torchrun_fpath)

        # Makes a specific log directory for TorchRun logs.
        (log_dir := trainer.log_dir / "torchrun").mkdir(parents=True, exist_ok=True)

        # This launcher expects to run on only one node. A multi-node TorchRun
        # launcher would require a way to launch TorchRun processes across
        # multiple target nodes.
        node_rank, num_nodes = 0, 1

        cmd = [
            self.config.torchrun_path,
            "--nproc-per-node",
            str(self.config.nproc_per_node),
            "--node-rank",
            str(node_rank),
            "--nnodes",
            str(num_nodes),
            "--master-addr",
            self.config.master_addr,
            "--master-port",
            str(self.config.master_port),
            "--start-method",
            self.config.start_method,
            "--log-dir",
            str(log_dir),
            # `--run-path` is a flag; the script path that follows is
            # torchrun's positional training-script argument.
            "--run-path",
            str(torchrun_fpath),
        ]

        # Launch the job. shlex.join quotes arguments, so the logged command
        # can be copied and pasted even when paths contain spaces.
        logger.info("Launching job with command: %s", shlex.join(cmd))
        subprocess.run(cmd, check=True)