Source code for ml.core.state

"""Defines a dataclass for keeping track of the current training state."""

import time
from dataclasses import dataclass
from typing import Literal, TypeVar, cast, get_args

from omegaconf import MISSING
from torch import nn

from ml.core.config import conf_field

Module = TypeVar("Module", bound=nn.Module)

Phase = Literal["train", "valid", "test"]


def set_phase(model: Module, phase: Phase) -> tuple[Module, Phase]:
    """Puts a model into the mode that matches the requested phase.

    The model is switched to training mode for the ``"train"`` phase and to
    evaluation mode for every other phase. The mode-switching call is only
    made when the model is not already in the desired mode.

    Args:
        model: The module whose mode should be updated.
        phase: The phase the model is about to be used for.

    Returns:
        A ``(model, phase)`` tuple containing the value returned by the
        mode switch (or the model as passed in, if no switch was needed) and
        the phase unchanged.
    """
    wants_training = phase == "train"
    is_training = bool(model.training)

    if wants_training and not is_training:
        model = model.train()
    elif is_training and not wants_training:
        model = model.eval()

    return model, phase
def cast_phase(raw_phase: str) -> Phase:
    """Validates a raw string and narrows it to the ``Phase`` literal type.

    Args:
        raw_phase: The candidate phase name.

    Returns:
        ``raw_phase`` unchanged, typed as ``Phase``.

    Raises:
        AssertionError: If ``raw_phase`` is not one of the ``Phase`` options.
    """
    args = get_args(Phase)
    # Raised explicitly instead of through an ``assert`` statement so that the
    # check is still performed when Python runs with ``-O``; the exception
    # type is kept as ``AssertionError`` so existing handlers keep working.
    if raw_phase not in args:
        raise AssertionError(f"Invalid phase: '{raw_phase}' Valid options are {args}")
    return cast(Phase, raw_phase)
[docs]@dataclass class State: num_epochs: int = conf_field(MISSING, help="Number of epochs so far") num_steps: int = conf_field(MISSING, help="Number of steps so far") num_epoch_steps: int = conf_field(MISSING, help="Number of steps in the current epoch") num_samples: int = conf_field(MISSING, help="Number of sample so far") num_epoch_samples: int = conf_field(MISSING, help="Number of samples in the current epoch") num_valid_steps: int = conf_field(MISSING, help="Number of validation steps so far") num_test_steps: int = conf_field(MISSING, help="Number of test steps so far") start_time_s: float = conf_field(MISSING, help="Start time of training") elapsed_time_s: float = conf_field(MISSING, help="Total elapsed time so far") raw_phase: str = conf_field(MISSING, help="Current training phase") @property def phase(self) -> Phase: return cast_phase(self.raw_phase) @phase.setter def phase(self, new_phase: Phase) -> None: self.raw_phase = new_phase
[docs] @classmethod def init_state(cls) -> "State": return cls( num_epochs=0, num_steps=0, num_epoch_steps=0, num_samples=0, num_epoch_samples=0, num_valid_steps=0, num_test_steps=0, start_time_s=time.time(), elapsed_time_s=0.0, raw_phase="train", )
@property def training(self) -> bool: return self.phase == "train"
[docs] def num_phase_steps(self, phase: Phase) -> int: match phase: case "train": return self.num_steps case "valid": return self.num_valid_steps case "test": return self.num_test_steps