"""Defines distributed training parameters.
These parameters apply to any distributed training jobs. For model-parallel
training, please refer to :mod:`ml.models.parallel.env`.
- ``RANK``: The rank of the current process.
- ``WORLD_SIZE``: The total number of processes.
- ``MASTER_ADDR``: The address of the master process.
- ``MASTER_PORT``: The port of the master process.
- ``INIT_METHOD``: The method to initialize the process group.
"""
import os
import random
import time

_RANK: int | None = None
_LOCAL_RANK: int | None = None
_WORLD_SIZE: int | None = None
_LOCAL_WORLD_SIZE: int | None = None
_MASTER_ADDR: str | None = None
_MASTER_PORT: int | None = None
_INIT_METHOD: str | None = None


def set_rank(rank: int) -> None:
    global _RANK
    if rank != _RANK:
        _RANK = rank
        os.environ["RANK"] = str(rank)
    else:
        raise ValueError(f"Rank {rank} is already set")


def get_rank_optional() -> int | None:
    return _RANK


def get_rank() -> int:
    return 0 if _RANK is None else _RANK


def set_local_rank(rank: int) -> None:
    global _LOCAL_RANK
    if rank != _LOCAL_RANK:
        _LOCAL_RANK = rank
        os.environ["LOCAL_RANK"] = str(rank)
    else:
        raise ValueError(f"Local rank {rank} is already set")


def get_local_rank_optional() -> int | None:
    return _LOCAL_RANK


def get_local_rank() -> int:
    return 0 if _LOCAL_RANK is None else _LOCAL_RANK


def set_world_size(world_size: int) -> None:
    global _WORLD_SIZE
    if world_size != _WORLD_SIZE:
        _WORLD_SIZE = world_size
        os.environ["WORLD_SIZE"] = str(world_size)
    else:
        raise ValueError(f"World size {world_size} is already set")


def get_world_size_optional() -> int | None:
    return _WORLD_SIZE


def get_world_size() -> int:
    return 1 if _WORLD_SIZE is None else _WORLD_SIZE


def set_local_world_size(local_world_size: int) -> None:
    global _LOCAL_WORLD_SIZE
    if local_world_size != _LOCAL_WORLD_SIZE:
        _LOCAL_WORLD_SIZE = local_world_size
        os.environ["LOCAL_WORLD_SIZE"] = str(local_world_size)
    else:
        raise ValueError(f"Local world size {local_world_size} is already set")


def get_local_world_size_optional() -> int | None:
    return _LOCAL_WORLD_SIZE


def get_local_world_size() -> int:
    return 1 if _LOCAL_WORLD_SIZE is None else _LOCAL_WORLD_SIZE


def set_master_addr(master_addr: str) -> None:
    global _MASTER_ADDR
    if master_addr != _MASTER_ADDR:
        os.environ["MASTER_ADDR"] = _MASTER_ADDR = master_addr
    else:
        raise ValueError(f"Master address {master_addr} is already set")


def get_master_addr() -> str:
    assert _MASTER_ADDR is not None, "Master address is not yet set"
    return _MASTER_ADDR


def set_master_port(port: int) -> None:
    global _MASTER_PORT
    if port != _MASTER_PORT:
        _MASTER_PORT = port
        os.environ["MASTER_PORT"] = str(port)
    else:
        raise ValueError(f"Master port {port} is already set")


def get_master_port() -> int:
    assert _MASTER_PORT is not None, "Master port is not yet set"
    return _MASTER_PORT


def is_master() -> bool:
    return get_rank() == 0


def is_distributed() -> bool:
    return _INIT_METHOD is not None


def get_init_method() -> str:
    assert _INIT_METHOD is not None, "Init method is not yet set"
    return _INIT_METHOD


def set_init_method(init_method: str) -> None:
    global _INIT_METHOD
    if init_method != _INIT_METHOD:
        os.environ["INIT_METHOD"] = _INIT_METHOD = init_method
    else:
        raise ValueError(f"Init method {init_method} is already set")


def get_random_port(default: int = 1337) -> int:
    # Picks a pseudo-random port in the range [10_000, 65_535), falling back to
    # ``default`` if anything goes wrong.
    try:
        return (hash(time.time()) + random.randint(0, 100000)) % (65_535 - 10_000) + 10_000
    except Exception:
        return default
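
# Hypothetical usage (not taken from this module): a launcher might pick the
# rendezvous endpoint on the master process before spawning workers, e.g.
#
#     set_master_addr("127.0.0.1")
#     set_master_port(get_random_port())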


def set_dist(
    rank: int,
    local_rank: int,
    world_size: int,
    local_world_size: int,
    master_addr: str,
    master_port: int,
    init_method: str,
) -> None:
    """Sets all of the distributed training parameters (and the matching
    environment variables) in a single call."""
    set_rank(rank)
    set_local_rank(local_rank)
    set_world_size(world_size)
    set_local_world_size(local_world_size)
    set_master_addr(master_addr)
    set_master_port(master_port)
    set_init_method(init_method)
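

if __name__ == "__main__":
    # Illustrative sketch only, not part of the library API: configure a fake
    # single-process "cluster" and show that the getters and the exported
    # environment variables stay in sync. All values below are made up.
    set_dist(
        rank=0,
        local_rank=0,
        world_size=1,
        local_world_size=1,
        master_addr="127.0.0.1",
        master_port=get_random_port(),
        init_method="env://",
    )
    assert is_master() and is_distributed()
    print("rank:", get_rank(), "->", os.environ["RANK"])
    print("world size:", get_world_size(), "->", os.environ["WORLD_SIZE"])
    print("master:", get_master_addr(), "port", get_master_port())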