Source code for ml.lr_schedulers.cosine

"""Defines a cosine learning rate scheduler."""

import math
from dataclasses import dataclass

from omegaconf import II, MISSING, OmegaConf

from ml.core.config import conf_field
from ml.core.registry import register_lr_scheduler
from ml.core.state import State
from ml.lr_schedulers.base import BaseLRScheduler, BaseLRSchedulerConfig


@dataclass
class CosineLRSchedulerConfig(BaseLRSchedulerConfig):
    total_steps: int = conf_field(II("task.max_steps"), help="Total number of steps to run")
    num_resets: int = conf_field(0, help="Number of times to reset learning")
    phase: int = conf_field(MISSING, help="Number of steps in a phase")
    ramp_up_percent: float = conf_field(0.05, help="Percent of phase to spend ramping up")
    ramp_up_steps: int = conf_field(MISSING, help="Number of steps to spend ramping up")
    eta_min: float = conf_field(0.01, help="Minimum learning rate scale")
    eta_max: float = conf_field(1.0, help="Maximum learning rate scale")

    @classmethod
    def resolve(cls, config: "CosineLRSchedulerConfig") -> None:
        # Split the total number of steps into `num_resets + 1` equal phases.
        if OmegaConf.is_missing(config, "phase"):
            config.phase = int(config.total_steps / (config.num_resets + 1))

        # Derive the ramp-up length from the phase length if it wasn't set explicitly.
        if OmegaConf.is_missing(config, "ramp_up_steps"):
            assert 0 <= config.ramp_up_percent < 1
            config.ramp_up_steps = int(config.phase * config.ramp_up_percent)
        else:
            assert config.ramp_up_steps < config.phase

        return super().resolve(config)


@register_lr_scheduler("cosine", CosineLRSchedulerConfig)
class CosineLRScheduler(BaseLRScheduler[CosineLRSchedulerConfig]):
    def get_lr_scale(self, state: State) -> float:
        phase, ramp_up = self.config.phase, self.config.ramp_up_steps
        eta_min, eta_max = self.config.eta_min, self.config.eta_max

        # Position within the current (ramp-up + cosine decay) cycle.
        phase_steps = state.num_steps % (phase + ramp_up)

        # Linear warmup from eta_min to 1.0 over the ramp-up steps.
        if phase_steps < ramp_up:
            return (1.0 - eta_min) * (phase_steps / ramp_up) + eta_min

        # Cosine decay from eta_max down to eta_min over the remainder of the phase.
        sigma = (phase_steps - ramp_up) / phase
        return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * sigma)) / 2
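To see the shape of the resulting schedule without constructing the scheduler or its state, the standalone sketch below mirrors the resolve arithmetic and the get_lr_scale math above. The helper name cosine_lr_scale and the example numbers are illustrative only and are not part of the library.

import math


def cosine_lr_scale(
    step: int,
    total_steps: int,
    num_resets: int = 0,
    ramp_up_percent: float = 0.05,
    eta_min: float = 0.01,
    eta_max: float = 1.0,
) -> float:
    # Same derivation as CosineLRSchedulerConfig.resolve.
    phase = int(total_steps / (num_resets + 1))
    ramp_up = int(phase * ramp_up_percent)

    # Same math as CosineLRScheduler.get_lr_scale.
    phase_steps = step % (phase + ramp_up)
    if phase_steps < ramp_up:
        return (1.0 - eta_min) * (phase_steps / ramp_up) + eta_min
    sigma = (phase_steps - ramp_up) / phase
    return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * sigma)) / 2


# With total_steps=1000 and num_resets=0: phase=1000, ramp_up=50, so the scale
# ramps from 0.01 to 1.0 over the first 50 steps, then follows a half cosine
# from 1.0 back down to 0.01 over the next 1000 steps.
for step in (0, 25, 50, 550, 1049):
    print(step, round(cosine_lr_scale(step, total_steps=1000), 3))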