"""A trainer mixin for logging CPU statistics.
This logs memory and CPU utilization in a background process, sending it to
the logging process every now and then. This is useful for detecting memory
leaks in your dataloader, among other issues.
"""
import logging
import multiprocessing as mp
import os
import time
from ctypes import Structure, c_double, c_uint16, c_uint64
from dataclasses import dataclass
from multiprocessing.managers import SyncManager, ValueProxy
from multiprocessing.synchronize import Event
from typing import TypeVar
import psutil
from torch.optim.optimizer import Optimizer
from ml.core.config import conf_field
from ml.core.state import State
from ml.lr_schedulers.base import SchedulerAdapter
from ml.trainers.base import ModelT, TaskT
from ml.trainers.mixins.monitor_process import MonitorProcessConfig, MonitorProcessMixin
logger: logging.Logger = logging.getLogger(__name__)
[docs]@dataclass
class CPUStatsConfig(MonitorProcessConfig):
cpu_stats_ping_interval: int = conf_field(1, help="How often to check stats (in seconds)")
cpu_stats_only_log_once: bool = conf_field(False, help="If set, only log read stats one time")
CPUStatsConfigT = TypeVar("CPUStatsConfigT", bound=CPUStatsConfig)
[docs]class CPUStats(Structure):
_fields_ = [
("cpu_percent", c_double),
("mem_percent", c_double),
("mem_rss", c_uint64),
("mem_vms", c_uint64),
("mem_shared", c_uint64),
("mem_rss_total", c_uint64),
("mem_vms_total", c_uint64),
("child_cpu_percent", c_double),
("child_mem_percent", c_double),
("num_child_procs", c_uint16),
]
[docs]@dataclass
class CPUStatsInfo:
cpu_percent: float
mem_percent: float
mem_rss: int
mem_vms: int
mem_shared: int
mem_rss_total: int
mem_vms_total: int
child_cpu_percent: float
child_mem_percent: float
num_child_procs: int
[docs] @classmethod
def from_stats(cls, stats: CPUStats) -> "CPUStatsInfo":
return cls(
cpu_percent=stats.cpu_percent,
mem_percent=stats.mem_percent,
mem_rss=stats.mem_rss,
mem_vms=stats.mem_vms,
mem_shared=stats.mem_shared,
mem_rss_total=stats.mem_rss_total,
mem_vms_total=stats.mem_vms_total,
child_cpu_percent=stats.child_cpu_percent,
child_mem_percent=stats.child_mem_percent,
num_child_procs=stats.num_child_procs,
)
[docs]def worker(
ping_interval: float,
stats: ValueProxy[CPUStats],
monitor_event: Event,
start_event: Event,
pid: int,
) -> None:
start_event.set()
proc, cur_pid = psutil.Process(pid), os.getpid()
logger.debug("Starting CPU stats monitor for PID %d with PID %d", pid, cur_pid)
def get_children() -> dict[int, psutil.Process]:
return {p.pid: p for p in proc.children(recursive=True) if p.pid != cur_pid}
child_procs = get_children()
while True:
try:
# Updates child processes, preserving the previous child process
# object. Otherwise the CPU percentage will be zero.
new_procs = get_children()
child_procs = {**new_procs, **child_procs}
child_procs = {pid: child_procs[pid] for pid in new_procs.keys()}
# Gets process memory info.
mem_info = proc.memory_info()
mem_rss_total = sum(p.memory_info().rss for p in child_procs.values()) + mem_info.rss
mem_vms_total = sum(p.memory_info().vms for p in child_procs.values()) + mem_info.vms
# Gets child CPU and memory percentages.
child_cpu_percent_total = sum(p.cpu_percent() for p in child_procs.values()) if child_procs else 0.0
child_mem_percent_total = sum(p.memory_percent() for p in child_procs.values()) if child_procs else 0.0
# Sets the CPU stats.
stats.set(
CPUStats(
cpu_percent=proc.cpu_percent(),
mem_percent=proc.memory_percent(),
mem_rss=int(mem_info.rss),
mem_vms=int(mem_info.vms),
mem_shared=int(getattr(mem_info, "shared", 0)),
mem_rss_total=int(mem_rss_total),
mem_vms_total=int(mem_vms_total),
child_cpu_percent=child_cpu_percent_total / len(child_procs),
child_mem_percent=child_mem_percent_total / len(child_procs),
num_child_procs=len(child_procs),
),
)
monitor_event.set()
except psutil.NoSuchProcess:
logger.info("No parent process; probably cleaning up")
time.sleep(ping_interval)
[docs]class CPUStatsMonitor:
def __init__(self, ping_interval: float, manager: SyncManager) -> None:
self._ping_interval = ping_interval
self._manager = manager
self._monitor_event = self._manager.Event()
self._start_event = self._manager.Event()
self._cpu_stats_smem = self._manager.Value(
CPUStats,
CPUStats(
cpu_percent=0.0,
mem_percent=0.0,
mem_rss=0,
mem_vms=0,
mem_shared=0,
mem_rss_total=0,
mem_vms_total=0,
child_cpu_percent=0.0,
child_mem_percent=0.0,
num_child_procs=0,
),
)
self._cpu_stats: CPUStatsInfo | None = None
self._proc: mp.Process | None = None
[docs] def get_if_set(self) -> CPUStatsInfo | None:
if self._monitor_event.is_set():
self._monitor_event.clear()
return CPUStatsInfo.from_stats(self._cpu_stats_smem.get())
return None
[docs] def get(self) -> CPUStatsInfo | None:
if (stats := self.get_if_set()) is not None:
self._cpu_stats = stats
return self._cpu_stats
[docs] def start(self, wait: bool = False) -> None:
if self._proc is not None:
raise RuntimeError("CPU stats monitor already started")
if self._monitor_event.is_set():
self._monitor_event.clear()
if self._start_event.is_set():
self._start_event.clear()
self._cpu_stats = None
self._proc = mp.Process(
target=worker,
args=(self._ping_interval, self._cpu_stats_smem, self._monitor_event, self._start_event, os.getpid()),
daemon=True,
)
self._proc.start()
if wait:
self._start_event.wait()
[docs] def stop(self) -> None:
if self._proc is None:
raise RuntimeError("CPU stats monitor not started")
if self._proc.is_alive():
self._proc.terminate()
logger.debug("Terminated CPU stats monitor; joining...")
self._proc.join()
self._proc = None
self._cpu_stats = None
[docs]class CPUStatsMixin(MonitorProcessMixin[CPUStatsConfigT, ModelT, TaskT]):
"""Defines a trainer mixin for getting CPU statistics."""
def __init__(self, config: CPUStatsConfigT) -> None:
super().__init__(config)
self._cpu_stats_monitor = CPUStatsMonitor(
ping_interval=self.config.cpu_stats_ping_interval,
manager=self._mp_manager,
)
[docs] def on_training_start(
self,
state: State,
task: TaskT,
model: ModelT,
optim: Optimizer | dict[str, Optimizer],
lr_sched: SchedulerAdapter | dict[str, SchedulerAdapter],
) -> None:
super().on_training_start(state, task, model, optim, lr_sched)
self._cpu_stats_monitor.start()
[docs] def on_training_end(
self,
state: State,
task: TaskT,
model: ModelT,
optim: Optimizer | dict[str, Optimizer],
lr_sched: SchedulerAdapter | dict[str, SchedulerAdapter],
) -> None:
super().on_training_end(state, task, model, optim, lr_sched)
self._cpu_stats_monitor.stop()
[docs] def on_step_start(
self,
state: State,
task: TaskT,
model: ModelT,
optim: Optimizer | dict[str, Optimizer],
lr_sched: SchedulerAdapter | dict[str, SchedulerAdapter],
) -> None:
super().on_step_start(state, task, model, optim, lr_sched)
monitor = self._cpu_stats_monitor
stats = monitor.get_if_set() if self.config.cpu_stats_only_log_once else monitor.get()
if stats is not None:
self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="💻 cpu")
self.logger.log_scalar("percent", stats.cpu_percent, namespace="💻 cpu")
self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="💻 cpu")
self.logger.log_scalar("percent", stats.mem_percent, namespace="💻 mem")
self.logger.log_scalar("shared", stats.mem_shared, namespace="💻 mem")
self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="💻 mem")
self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="💻 mem")
self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="💻 mem")
self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="💻 mem")
self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="💻 mem")