Source code for ml.trainers.mixins.monitor_process

"""Defines a base trainer mixin for handling subprocess monitoring jobs."""

import logging
import multiprocessing as mp
from dataclasses import dataclass
from typing import Generic, TypeVar

from torch.optim.optimizer import Optimizer

from ml.core.state import State
from ml.lr_schedulers.base import SchedulerAdapter
from ml.trainers.base import BaseTrainer, BaseTrainerConfig, ModelT, TaskT

# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class MonitorProcessConfig(BaseTrainerConfig):
    """Configuration for trainers that use ``MonitorProcessMixin``.

    Declares no fields of its own; everything is inherited from
    ``BaseTrainerConfig``.
    """


# Type variable for the config type accepted by ``MonitorProcessMixin``;
# bounded so that only ``MonitorProcessConfig`` (or subclasses) are allowed.
MonitorProcessConfigT = TypeVar("MonitorProcessConfigT", bound=MonitorProcessConfig)
class MonitorProcessMixin(
    BaseTrainer[MonitorProcessConfigT, ModelT, TaskT],
    Generic[MonitorProcessConfigT, ModelT, TaskT],
):
    """Defines a base trainer mixin for handling monitoring processes.

    The mixin owns a ``multiprocessing`` manager, ``self._mp_manager``. It is
    started when the trainer is constructed and shut down when training ends.
    If training is started again after a previous run has ended, a new manager
    is started so that ``self._mp_manager`` is usable during every run.
    """

    def __init__(self, config: MonitorProcessConfigT) -> None:
        super().__init__(config)
        self._mp_manager = mp.Manager()
        # Whether ``self._mp_manager`` has been started and not yet shut down
        # by this mixin. The training hooks consult it so they neither spawn a
        # redundant manager process nor shut the same manager down twice.
        self._mp_manager_running = True

    def _ensure_mp_manager(self) -> None:
        """Starts a fresh manager if the current one has been shut down."""
        if not self._mp_manager_running:
            self._mp_manager = mp.Manager()
            self._mp_manager_running = True

    def on_training_start(
        self,
        state: State,
        task: TaskT,
        model: ModelT,
        optim: Optimizer | dict[str, Optimizer],
        lr_sched: SchedulerAdapter | dict[str, SchedulerAdapter],
    ) -> None:
        """Runs the parent start hook, then ensures the manager is running.

        Args:
            state: Forwarded unchanged to the parent class's hook.
            task: Forwarded unchanged to the parent class's hook.
            model: Forwarded unchanged to the parent class's hook.
            optim: Forwarded unchanged to the parent class's hook.
            lr_sched: Forwarded unchanged to the parent class's hook.
        """
        super().on_training_start(state, task, model, optim, lr_sched)
        # Reuse the manager started in ``__init__`` while it is still alive.
        # Unconditionally replacing it would spawn a second server process and
        # leave the first one (and any proxies created from it) outside the
        # shutdown performed in ``on_training_end``.
        self._ensure_mp_manager()

    def on_training_end(
        self,
        state: State,
        task: TaskT,
        model: ModelT,
        optim: Optimizer | dict[str, Optimizer],
        lr_sched: SchedulerAdapter | dict[str, SchedulerAdapter],
    ) -> None:
        """Runs the parent end hook, then shuts down the manager.

        The manager is shut down even if the parent hook raises; the parent's
        exception still propagates to the caller.

        Args:
            state: Forwarded unchanged to the parent class's hook.
            task: Forwarded unchanged to the parent class's hook.
            model: Forwarded unchanged to the parent class's hook.
            optim: Forwarded unchanged to the parent class's hook.
            lr_sched: Forwarded unchanged to the parent class's hook.
        """
        try:
            super().on_training_end(state, task, model, optim, lr_sched)
        finally:
            # Always release the manager's server process, and only once.
            if self._mp_manager_running:
                self._mp_manager.shutdown()
                self._mp_manager_running = False