"""Defines checkpoint utility functions.

These functions can be used to load a model from an arbitrary config file
and checkpoint. Note that loading may fail if the checkpoint has been moved
to a different location since it was saved.
"""
import logging
import re
import tempfile
from pathlib import Path
from typing import Any, TypeVar, cast
import torch
import yaml
from omegaconf import MISSING, Container, DictConfig, OmegaConf
from omegaconf._utils import get_yaml_loader
from torchvision.datasets.utils import download_url
from ml.core.env import get_model_dir, set_ml_config_path
from ml.core.registry import Objects, register_model, register_task
from ml.models.base import BaseModel
from ml.tasks.base import BaseTask
from ml.trainers.base import BaseTrainer
from ml.utils.data import check_md5, check_sha256
from ml.utils.device.auto import detect_device
from ml.utils.timer import Timer
logger = logging.getLogger(__name__)
T = TypeVar("T")
def is_missing(cfg: Any, key: str) -> bool:  # noqa: ANN401
"""Utility function for checking if a config key is missing.
This is for cases when you are using a raw dataclass rather than an
OmegaConf container but want to treat them the same way.
Args:
cfg: The config to check
key: The key to check
Returns:
Whether or not the key is missing a value in the config
"""
if isinstance(cfg, Container) and OmegaConf.is_missing(cfg, key):
return True
if getattr(cfg, key) is MISSING:
return True
return False
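
# Illustrative usage (the dataclass below is hypothetical; it only sketches how
# ``is_missing`` treats raw dataclasses and OmegaConf containers uniformly):
#
#     @dataclasses.dataclass
#     class TrainConfig:
#         lr: float = MISSING
#
#     assert is_missing(TrainConfig(), "lr")
#     assert is_missing(OmegaConf.structured(TrainConfig), "lr")
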
def instantiate_config(config: str | Path | DictConfig | dict) -> Objects:
"""Builds the objects from the raw config.
Args:
config: The config to use. If a string or a Path, it is expected to
be a path to a YAML file.
Returns:
The instantiated objects.
"""
if isinstance(config, (str, Path)):
set_ml_config_path(Path(config))
config = cast(DictConfig, OmegaConf.load(config))
if not OmegaConf.is_dict(config):
raise ValueError(f"Expected config to be a dict, got {type(config)}")
elif isinstance(config, dict):
config = OmegaConf.create(config)
Objects.update_config(config)
Objects.resolve_config(config)
return Objects.parse_raw_config(config)
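
# Example usage (the path is hypothetical, and this assumes the returned
# ``Objects`` container exposes the instantiated ``model`` and ``task``):
#
#     objs = instantiate_config("runs/my_experiment/config.yaml")
#     model, task = objs.model, objs.task
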
def get_checkpoint_path(trainer: BaseTrainer, config_path: str | Path, ckpt_path: str | Path | None) -> Path:
    """Resolves the path of the checkpoint to load.

    The lookup order is: the explicitly passed ``ckpt_path``, the trainer's
    own checkpoint path, a ``ckpt.pt`` file next to the config, and finally
    the most recently modified ``ckpt*.pt`` file under the config directory.

    Raises:
        RuntimeError: If no checkpoint could be found.
    """
if ckpt_path is not None:
ckpt_path = Path(ckpt_path)
if ckpt_path.exists():
return ckpt_path
logger.warning("Could not find the passed checkpoint at %s", ckpt_path)
# Tries loading the checkpoint that the trainer thinks exists.
ckpt_path = trainer.get_ckpt_path()
if ckpt_path.exists():
return ckpt_path
logger.warning("Could not find trainer checkpoint at %s", ckpt_path)
# Tries loading other checkpoints.
config_path = Path(config_path)
ckpt_path = config_path.parent / "ckpt.pt"
if ckpt_path.exists():
return ckpt_path
logger.warning("Could not find checkpoint at %s", ckpt_path)
    # Recursively searches for checkpoints under the config directory.
ckpt_paths = list(config_path.parent.rglob("ckpt*.pt"))
if ckpt_paths:
return max(ckpt_paths, key=lambda p: p.stat().st_mtime)
logger.warning("Could not find checkpoints in config directory %s", config_path.parent)
raise RuntimeError("Could not find a checkpoint to load")
def load_model_and_task(
config_path: str | Path | None = None,
ckpt_path: str | Path | None = None,
to_device: bool = True,
missing_ckpt_okay: bool = False,
) -> tuple[BaseModel, BaseTask]:
    """Loads a trained model and task from a config and an optional checkpoint path.

    Args:
        config_path: The path to the config file.
        ckpt_path: The path to the checkpoint file. If None, the checkpoint is
            resolved automatically: first the trainer's own checkpoint path is
            checked, then a ``ckpt.pt`` file next to the config, and finally
            the most recently modified ``ckpt*.pt`` file under the config
            directory.
        to_device: Whether to move the model and task to the auto-detected
            device.
        missing_ckpt_okay: Whether to return a model and task even if the
            checkpoint is missing.

    Returns:
        The model and task loaded from the checkpoint.

    Raises:
        ValueError: If both ``config_path`` and ``ckpt_path`` are None.
        RuntimeError: If the checkpoint is missing and ``missing_ckpt_okay``
            is False.
    """
with Timer("loading checkpoint"):
concrete_ckpt_path: str | Path | None = None
trainer: BaseTrainer
if config_path is None:
if ckpt_path is None:
raise ValueError("Must provide either a config path or a checkpoint path")
ckpt = torch.load(ckpt_path, map_location="cpu")
if "config" not in ckpt:
raise ValueError("Could not find a config in the checkpoint")
concrete_ckpt_path = ckpt_path
set_ml_config_path(Path(ckpt_path).parent / "config.yaml")
config_yaml = yaml.load(ckpt["config"], Loader=get_yaml_loader())
config = OmegaConf.create(config_yaml)
trainer = BaseTrainer(config.trainer)
else:
set_ml_config_path(Path(config_path))
config = cast(DictConfig, OmegaConf.load(config_path))
trainer = BaseTrainer(config.trainer)
# Uses the dummy trainer to load the checkpoint.
try:
concrete_ckpt_path = get_checkpoint_path(trainer, config_path, ckpt_path)
except RuntimeError:
if missing_ckpt_okay:
logger.exception("Could not load checkpoint")
else:
raise
model = register_model.build_entry_non_null(config)
task = register_task.build_entry_non_null(config)
if concrete_ckpt_path is not None:
trainer.load_checkpoint(concrete_ckpt_path, task, model)
if to_device:
device = detect_device()
device.module_to(model, with_dtype=False)
device.module_to(task, with_dtype=False)
return model, task
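
# Illustrative usage (the path and the ``batch`` variable are placeholders; a
# minimal sketch of restoring a run and running the model on one batch):
#
#     model, task = load_model_and_task("runs/my_experiment/config.yaml")
#     model.eval()
#     with torch.no_grad():
#         output = model(batch)
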
def ensure_downloaded(
url: str,
*dnames: str,
md5: str | None = None,
sha256: str | None = None,
is_tmp: bool = False,
recheck_hash: bool = False,
) -> Path:
    """Ensures that a checkpoint URL has been downloaded.

    This provides a consistent way of organizing pre-trained models by saving
    them to a predictable location under the model directory.

    Args:
        url: The URL to download.
        dnames: The path components of the download location, relative to the
            model directory; the final component is the file name.
        md5: The MD5 hash of the file, if known.
        sha256: The SHA256 hash of the file, if known.
        is_tmp: If set, download to a temporary directory instead of
            ``get_model_dir()``.
        recheck_hash: Whether to recheck the hash of the file if it already
            exists.

    Returns:
        The path to the downloaded file.
    """
assert len(dnames) >= 1, "Must provide at least 1 directory name"
filepath = Path(tempfile.mkdtemp("models")) if is_tmp else get_model_dir()
for dname in dnames:
filepath = filepath / dname
(root := filepath.parent).mkdir(parents=True, exist_ok=True)
def check_hashes() -> bool:
return filepath.is_file() and check_sha256(filepath, sha256) and check_md5(filepath, md5)
def download_file() -> None:
download_url(url, root=root, filename=filepath.name)
assert filepath.is_file(), f"Failed to download {url} to {filepath}"
if not check_hashes():
filepath.unlink()
raise RuntimeError(f"Hashes for {url} do not match")
# If the file does not exist, download it and check the hashes.
if not filepath.exists():
download_file()
# By default, assume the downloaded file hash is correct.
if not recheck_hash:
return filepath
# Check the file hashes again, to ensure the file was not corrupted.
if not check_hashes():
filepath.unlink()
download_file()
return filepath
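
# Illustrative usage (the URL, directory name, and hash are placeholders):
#
#     ckpt_file = ensure_downloaded(
#         "https://example.com/pretrained/model.pt",
#         "my_model",
#         "model.pt",
#         sha256="<expected-sha256>",
#     )
#     state_dict = torch.load(ckpt_file, map_location="cpu")
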
def get_state_dict_prefix(
ckpt: dict[str, T],
prefix: str | None = None,
suffix: str | None = None,
regexp: re.Pattern[str] | None = None,
) -> dict[str, T]:
    """Filters and renames checkpoint keys by prefix, suffix, or regexp.

    Args:
        ckpt: The checkpoint to modify.
        prefix: Keep only keys starting with this prefix, and strip it from
            the returned keys.
        suffix: Keep only keys ending with this suffix, and strip it from
            the returned keys.
        regexp: Keep only keys matching this pattern (keys are not modified).

    Returns:
        The modified checkpoint.
    """
if prefix is not None:
ckpt = {k[len(prefix) :]: v for k, v in ckpt.items() if k.startswith(prefix)}
if suffix is not None:
ckpt = {k[: -len(suffix)]: v for k, v in ckpt.items() if k.endswith(suffix)}
if regexp is not None:
ckpt = {k: v for k, v in ckpt.items() if regexp.match(k)}
return ckpt
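
# Illustrative usage: stripping the "module." prefix that DistributedDataParallel
# adds to parameter names before loading into a bare model (the checkpoint layout
# is assumed):
#
#     ckpt = torch.load("ckpt.pt", map_location="cpu")
#     state_dict = get_state_dict_prefix(ckpt["model"], prefix="module.")
#     model.load_state_dict(state_dict)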