# mypy: disable-error-code="override"
"""Defines utilities for training distributed PyTorch models.

The canonical way to use this module is to call :func:`launch_subprocesses`
from the main function of your script. This launches one subprocess per
device and initializes the process group for distributed training. The number
of processes and the backend can be changed through the provided config.
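
Example (an illustrative sketch; ``train`` and the specific config values are
hypothetical, not part of this module)::

    def train() -> None:
        ...  # Build the model and run the training loop on the current rank.

    if __name__ == "__main__":
        # Fields left unset (master port, local world size, ...) are resolved
        # automatically from the current machine.
        launch_subprocesses(train, MultiprocessConfig(world_size=2, backend="gloo"))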
"""

import functools
import logging
import os
import sys
import traceback
from dataclasses import dataclass
from typing import Callable, ParamSpec

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from omegaconf import MISSING
from torch import Tensor
from torch.autograd.function import Function, FunctionCtx

from ml.core.config import conf_field
from ml.utils.checkpoint import is_missing
from ml.utils.distributed import get_init_method, get_rank, get_world_size, set_dist
from ml.utils.logging import INFOALL, configure_logging
from ml.utils.networking import get_unused_port
from ml.utils.parallel import init_parallelism

DEFAULT_PORT = 29500
logger: logging.Logger = logging.getLogger(__name__)
P = ParamSpec("P")


@dataclass
class MultiprocessConfig:
    """Configuration for launching distributed training subprocesses."""

    rank: int = conf_field(-1, help="The rank of the process")
    local_rank: int = conf_field(-1, help="The local rank of the process")
    world_size: int = conf_field(MISSING, help="The total number of processes")
    local_world_size: int = conf_field(MISSING, help="The number of processes per machine")
    master_addr: str = conf_field("127.0.0.1", help="The address of the master process")
    master_port: int = conf_field(MISSING, help="The port of the master process")
    init_method: str = conf_field("env://", help="The initialization method")
    model_parallelism: int = conf_field(1, help="The number of model parallel processes")
    pipeline_parallelism: int = conf_field(1, help="The number of pipeline parallel processes")
    backend: str | None = conf_field(None, help="The distributed backend")
    model_parallel_backend: str | None = conf_field(None, help="The model parallel backend")
    pipeline_parallel_backend: str | None = conf_field(None, help="The pipeline parallel backend")
    data_parallel_backend: str | None = conf_field(None, help="The data parallel backend")
    launch_method: str = conf_field("forkserver", help="The launch method for multiprocessing")

    @classmethod
    def resolve(cls, config: "MultiprocessConfig") -> None:
        """Resolves any missing config fields from the current machine.

        Args:
            config: The config to update in-place.
        """
        device_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
        if is_missing(config, "world_size"):
            config.world_size = device_count
        if is_missing(config, "local_world_size"):
            config.local_world_size = min(device_count, config.world_size)
        if is_missing(config, "master_port"):
            config.master_port = get_unused_port(DEFAULT_PORT)


def init_process_group_from_backend(backend: str | dist.Backend | None = None) -> None:
    """Initializes the default process group, then runs a dummy all-reduce as a sanity check.

    Args:
        backend: The distributed backend to use; if not provided, it is chosen
            by :func:`get_distributed_backend`.
    """
    if backend is None:
        backend = get_distributed_backend()
    init_method, world_size, rank = get_init_method(), get_world_size(), get_rank()
    logger.log(INFOALL, "Initializing %d / %d using %s - %s", rank, world_size, init_method, backend)
    dist.init_process_group(backend=backend, init_method=init_method, world_size=world_size, rank=rank)
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        torch.cuda.set_device(rank % device_count)
    logger.info("Initialized process group; running dummy all-reduce")
    dist.all_reduce(torch.zeros(1, device="cuda" if torch.cuda.is_available() else "cpu"))
    logger.info("Dummy all-reduce succeeded")


def init_dist(
    rank: int,
    local_rank: int,
    world_size: int,
    local_world_size: int,
    master_addr: str,
    master_port: int,
    init_method: str,
    backend: str | dist.Backend | None = None,
) -> None:
    """Initializes the distributed environment.

    Args:
        rank: The rank of the current process.
        local_rank: The local rank of the current process.
        world_size: The total number of processes.
        local_world_size: The number of processes per machine.
        master_addr: The address of the master process.
        master_port: The port of the master process.
        init_method: The initialization method.
        backend: The distributed backend.
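
    Example (an illustrative sketch of a single-process setup; the address and
    port values here are arbitrary, not defaults enforced by this function)::

        init_dist(
            rank=0,
            local_rank=0,
            world_size=1,
            local_world_size=1,
            master_addr="127.0.0.1",
            master_port=29500,
            init_method="env://",
            backend="gloo",
        )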
"""
set_dist(rank, local_rank, world_size, local_world_size, master_addr, master_port, init_method)
init_process_group_from_backend(backend)


@functools.lru_cache
def default_backend() -> str:
    """Returns the default distributed backend: NCCL if CUDA is available, otherwise Gloo."""
    if torch.cuda.is_available():
        return "nccl"
    return "gloo"


def get_distributed_backend() -> dist.Backend:
    """Returns the distributed backend to use.

    The ``TORCH_DISTRIBUTED_BACKEND`` environment variable can be used to
    change the distributed backend to something other than the default. For
    example, if you're on a system with some strange NCCL errors, you can try
    setting this environment variable to ``gloo``.
    """
    return dist.Backend(os.environ.get("TORCH_DISTRIBUTED_BACKEND", default_backend()))


def set_distributed_backend(backend: str) -> None:
    """Sets the distributed backend via the ``TORCH_DISTRIBUTED_BACKEND`` environment variable.

    Args:
        backend: The backend to use, e.g. ``"nccl"`` or ``"gloo"``.
    """
    os.environ["TORCH_DISTRIBUTED_BACKEND"] = backend
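
# For example, to force the Gloo backend when NCCL misbehaves, either export the
# environment variable before launching (``train.py`` is a hypothetical entry
# point):
#
#     TORCH_DISTRIBUTED_BACKEND=gloo python train.py
#
# or call ``set_distributed_backend("gloo")`` in each worker before the process
# group is initialized, e.g. from the ``setup`` callable passed to
# ``launch_subprocesses``.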


def init_and_run(
    func: Callable[P, None],
    cfg: MultiprocessConfig,
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Initializes logging, the process group, and parallelism, then calls ``func``.

    Args:
        func: The function to run after initialization.
        cfg: The distributed training configuration.
        args: The positional arguments to pass to the function.
        kwargs: The keyword arguments to pass to the function.
    """
    configure_logging(rank=cfg.rank, world_size=cfg.world_size)

    init_dist(
        rank=cfg.rank,
        local_rank=cfg.local_rank,
        world_size=cfg.world_size,
        local_world_size=cfg.local_world_size,
        master_addr=cfg.master_addr,
        master_port=cfg.master_port,
        init_method=cfg.init_method,
        backend=cfg.backend,
    )

    init_parallelism(
        model_parallelism=cfg.model_parallelism,
        pipeline_parallelism=cfg.pipeline_parallelism,
        mp_backend=cfg.backend if cfg.model_parallel_backend is None else cfg.model_parallel_backend,
        pp_backend=cfg.backend if cfg.pipeline_parallel_backend is None else cfg.pipeline_parallel_backend,
        dp_backend=cfg.backend if cfg.data_parallel_backend is None else cfg.data_parallel_backend,
    )

    func(*args, **kwargs)


def _func_wrapped(
    func: Callable[P, None],
    setup: Callable[[], None] | None,
    cfg: MultiprocessConfig,
    error_queue: "mp.SimpleQueue[str | None]",
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    try:
        if setup is not None:
            setup()
        init_and_run(func, cfg, *args, **kwargs)
    except KeyboardInterrupt:
        logger.info("Caught KeyboardInterrupt; exiting")
    except Exception:
        # Send the traceback back to the parent process before exiting.
        error_queue.put(traceback.format_exc())
        sys.exit(1)

    # Signal to the parent process that this worker finished cleanly.
    error_queue.put(None)


def launch_subprocesses(
    func: Callable[P, None],
    cfg: MultiprocessConfig,
    setup: Callable[[], None] | None = None,
    rank_offset: int = 0,
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Launches a function in multiple subprocesses.

    Args:
        func: The function to launch.
        cfg: The configuration for the function.
        setup: A function to run in each subprocess before initializing and
            running ``func``.
        rank_offset: The offset to add to the rank of each subprocess.
        args: The positional arguments to pass to the function.
        kwargs: The keyword arguments to pass to the function.

    Raises:
        RuntimeError: If the function fails in any subprocess.
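
    Example (an illustrative sketch; ``use_gloo``, ``train``, and the config
    values are hypothetical, not part of this module)::

        def use_gloo() -> None:
            # Runs in each subprocess before the process group is initialized.
            set_distributed_backend("gloo")

        def train(batch_size: int) -> None:
            ...  # Run the training loop on the current rank.

        launch_subprocesses(train, MultiprocessConfig(world_size=2), setup=use_gloo, batch_size=16)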
"""
MultiprocessConfig.resolve(cfg)
if cfg.world_size <= 1:
logger.warning("Multi-process trainer expects more than one device; running single-process")
cfg.rank = 0
init_and_run(func, cfg, *args, **kwargs)
return
logger.info("Launching %d training workers", cfg.world_size)
ctx = mp.get_context(cfg.launch_method)
error_queues: list["mp.SimpleQueue[str | None]"] = []
procs = []
for rank in range(cfg.world_size):
rank = rank + rank_offset
error_queue = ctx.SimpleQueue()
cfg.rank = rank
cfg.local_rank = rank % cfg.local_world_size
proc = ctx.Process(
target=_func_wrapped,
args=[func, setup, cfg, error_queue, *args],
kwargs=kwargs,
daemon=False,
name=f"worker-{rank}",
)
logger.debug("Started process %d", rank)
proc.start()
error_queues.append(error_queue)
procs.append(proc)
pctx = mp.ProcessContext(procs, error_queues)
while not pctx.join():
pass
for rank, error_queue in enumerate(error_queues):
error = error_queue.get()
if error:
raise RuntimeError(f"Process {rank} failed with error:\n{error}")


class _AllToAll(Function):
    """Differentiable all-to-all operation, used by :func:`all_to_all`."""

    @staticmethod
    def forward(ctx: FunctionCtx, group: dist.ProcessGroup, input: Tensor) -> Tensor:
        ctx.group = group
        input = input.contiguous()
        output = torch.empty_like(input)
        if dist.is_initialized():
            dist.all_to_all_single(output, input, group=group)
        else:
            # Fall back to the identity when not running distributed.
            assert group is None
            output = input
        return output

    @staticmethod
    def backward(ctx: FunctionCtx, *grad_output: Tensor) -> tuple[None, Tensor]:
        # The gradient of an all-to-all is another all-to-all over the same group.
        return (None, _AllToAll.apply(ctx.group, *grad_output))


def all_to_all(input: Tensor, group: dist.ProcessGroup | None) -> Tensor:
    """Performs an all-to-all operation on the input tensor.

    Args:
        input: The input tensor.
        group: The process group to use for the all-to-all operation; if
            ``None``, the default (world) group is used.

    Returns:
        The output tensor.
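
    Example (an illustrative sketch; assumes the default process group has been
    initialized with a backend that supports all-to-all, e.g. NCCL, and that
    dim 0 of ``x`` is divisible by the world size)::

        # Rank ``i`` sends chunk ``j`` of ``x`` to rank ``j`` and receives
        # chunk ``i`` from every rank, concatenated along dim 0.
        x = torch.randn(get_world_size() * 4, 8, device="cuda")
        y = all_to_all(x, group=None)  # ``None`` selects the default (world) group.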
"""
if group is None:
group = dist.group.WORLD
return _AllToAll.apply(group, input)