Source code for ml.lr_schedulers.scripts.plot

# mypy: disable-error-code="import"
"""A simple script for plotting learning rate schedules with various parameters.

This script can be used as follows:

.. code-block:: bash

    python -m ml.lr_schedulers.scripts.plot linear /path/to/save.png
"""

import argparse
from pathlib import Path

from ml.core.registry import register_lr_scheduler
from ml.core.state import State
from ml.utils.argparse import add_args, from_args


def main() -> None:
    """Plots a learning rate schedule.

    Command-line usage::

        python -m ml.lr_schedulers.scripts.plot LR_SCHEDULER [SAVE_PATH] [-n NUM_ITERS] [-s STRIDE] [scheduler args...]

    ``LR_SCHEDULER`` is looked up with ``register_lr_scheduler.lookup``. Any
    arguments not recognized by the plotting parser are forwarded to a second
    parser built from the scheduler's config class via ``add_args``. The LR
    scale is sampled at steps ``0, STRIDE, 2 * STRIDE, ...`` strictly below
    ``NUM_ITERS``. If ``SAVE_PATH`` is omitted the plot is shown interactively;
    otherwise it is saved there, creating parent directories as needed.

    Raises:
        ImportError: If matplotlib is not installed.
        SystemExit: If the command-line arguments are invalid (raised by argparse),
            including a non-positive ``--num-iters`` or ``--stride``.
    """
    try:
        import matplotlib.pyplot as plt
    except ModuleNotFoundError as e:
        raise ImportError("Please install matplotlib to use this script: `pip install matplotlib`") from e

    # Gets the plotting-specific arguments.
    parser = argparse.ArgumentParser(description="Plots a learning rate schedule")
    parser.add_argument("lr_scheduler", help="Which scheduler to plot")
    parser.add_argument("save_path", nargs="?", help="Where to save the plot")
    parser.add_argument("-n", "--num-iters", type=int, default=100_000, help="Number of iterations")
    parser.add_argument("-s", "--stride", type=int, default=100, help="Stride between iterations")
    # Unrecognized arguments are kept so they can be parsed as scheduler config below.
    args, cli_args = parser.parse_known_args()

    num_iters: int = args.num_iters
    stride: int = args.stride
    # `range` raises an opaque ValueError for a zero step and silently yields
    # nothing for non-positive values (an empty plot); report these as CLI errors.
    if num_iters <= 0:
        parser.error(f"--num-iters must be positive, got {num_iters}")
    if stride <= 0:
        parser.error(f"--stride must be positive, got {stride}")
    save_path = None if args.save_path is None else Path(args.save_path)

    # Parses config-specific arguments and builds the learning rate scheduler.
    scheduler_cls, scheduler_config_cls = register_lr_scheduler.lookup(args.lr_scheduler)
    scheduler_parser = argparse.ArgumentParser(description=f"Parser for {args.lr_scheduler} scheduler")
    add_args(scheduler_parser, scheduler_config_cls)
    scheduler_args = scheduler_parser.parse_args(cli_args)
    scheduler_config = from_args(scheduler_args, scheduler_config_cls)
    scheduler = scheduler_cls(scheduler_config)

    # Samples the schedule by setting the step counter on a single state object.
    state = State.init_state()
    xs: list[int] = []
    ys: list[float] = []
    for step in range(0, num_iters, stride):
        state.num_steps = step
        lr_scale = scheduler.get_lr_scale(state)
        xs.append(step)
        ys.append(lr_scale)

    plt.figure()
    plt.plot(xs, ys)
    plt.xlabel("Step")
    plt.ylabel("LR scale")
    plt.title(args.lr_scheduler)
    if save_path is None:
        plt.show()
    else:
        save_path.parent.mkdir(exist_ok=True, parents=True)
        plt.savefig(save_path)
if __name__ == "__main__":
    # python -m ml.lr_schedulers.scripts.plot linear /path/to/save.png
    main()