ml.utils.checkpoint

Defines checkpoint utility functions.

These functions can be used to load a model from an arbitrary config file and checkpoint. Note that loading may fail if the checkpoint has been moved from the location it was saved to, since the config and checkpoint paths are resolved relative to each other.

ml.utils.checkpoint.is_missing(cfg: Any, key: str) → bool

Utility function for checking if a config key is missing.

This is for cases when you are using a raw dataclass rather than an OmegaConf container but want to treat them the same way.

Parameters:
  • cfg – The config to check

  • key – The key to check

Returns:

Whether or not the key is missing a value in the config
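As an illustration, here is a minimal sketch of the dataclass side of this check. This is a hypothetical simplification, not the library's implementation: the real function also understands OmegaConf containers (where missing values are represented by the `MISSING` sentinel), while this version only inspects plain attributes.

```python
from dataclasses import dataclass
from typing import Any, Optional

# Hypothetical simplified version for plain dataclasses only; the real
# is_missing also handles OmegaConf containers and their MISSING sentinel.
def is_missing(cfg: Any, key: str) -> bool:
    sentinel = object()
    value = getattr(cfg, key, sentinel)
    # Treat both "attribute absent" and "attribute unset" as missing.
    return value is sentinel or value is None

@dataclass
class TrainConfig:
    lr: float = 1e-3
    ckpt_path: Optional[str] = None

cfg = TrainConfig()
print(is_missing(cfg, "lr"))         # False
print(is_missing(cfg, "ckpt_path"))  # True
```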

ml.utils.checkpoint.instantiate_config(config: str | Path | DictConfig | dict) → Objects

Builds the objects from the raw config.

Parameters:

config – The config to use. If a string or a Path, it is expected to be a path to a YAML file.

Returns:

The instantiated objects.
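The type dispatch this function performs over its input can be sketched as follows. This is a hypothetical simplification: JSON stands in for YAML so the example is self-contained, and the step that builds `Objects` from the parsed config is elided.

```python
import json
from pathlib import Path
from typing import Any, Union

# Hypothetical helper sketching the dispatch: accept a path (load and
# parse it) or an already-parsed mapping (use it directly). The real
# instantiate_config parses YAML and then builds Objects from the result.
def load_raw_config(config: Union[str, Path, dict]) -> "dict[str, Any]":
    if isinstance(config, (str, Path)):
        with open(config) as f:
            return json.load(f)
    return dict(config)

cfg = load_raw_config({"model": {"name": "mlp"}, "task": {"name": "mnist"}})
print(cfg["model"]["name"])  # mlp
```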

ml.utils.checkpoint.get_checkpoint_path(trainer: BaseTrainer, config_path: str | Path, ckpt_path: str | Path | None) → Path
ml.utils.checkpoint.load_model_and_task(config_path: str | Path | None = None, ckpt_path: str | Path | None = None, to_device: bool = True, missing_ckpt_okay: bool = False) → tuple[ml.models.base.BaseModel, ml.tasks.base.BaseTask]

Loads a trained model and task from a config file and an optional checkpoint path.

Parameters:
  • config_path – The path to the config file.

  • ckpt_path – The path to the checkpoint file. If None, the latest checkpoint is used: an adjacent checkpoints directory is first checked for a ckpt.pt file, and failing that, the checkpoint file is looked for in the same directory as the config.

  • to_device – Whether to move the model to the device specified in the config.

  • missing_ckpt_okay – Whether to return a model and task even if the checkpoint is missing.

Returns:

The model and task loaded from the checkpoint

Raises:
  • ValueError – If both config_path and ckpt_path are None.

  • RuntimeError – If the checkpoint is missing and missing_ckpt_okay is False.
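The default checkpoint lookup described for ckpt_path can be sketched like this. The exact search order is an assumption inferred from the parameter description above, not the library's actual implementation.

```python
from pathlib import Path
from typing import Optional

# Assumed search order (hypothetical sketch): an adjacent "checkpoints"
# directory is checked for ckpt.pt first, then the config's own directory.
def find_default_ckpt(config_path: Path) -> Optional[Path]:
    candidates = [
        config_path.parent / "checkpoints" / "ckpt.pt",
        config_path.parent / "ckpt.pt",
    ]
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None
```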

ml.utils.checkpoint.ensure_downloaded(url: str, *dnames: str, md5: str | None = None, sha256: str | None = None, is_tmp: bool = False, recheck_hash: bool = False) → Path

Ensures that a checkpoint URL has been downloaded.

This provides a convenient way of organizing pre-trained models by saving them to a consistent location.

Parameters:
  • url – The URL to download.

  • dnames – The directory components to download to, relative to the model directory; the final component should be the file name.

  • md5 – The MD5 hash of the file, if known.

  • sha256 – The SHA256 hash of the file, if known.

  • is_tmp – If set, download to tmp/ instead of get_model_dir().

  • recheck_hash – Whether to recheck the hash of the file if it already exists.

Returns:

The path to the downloaded file.
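The integrity check performed when md5 or sha256 is supplied can be sketched as follows. This is assumed behavior based on the parameter descriptions; the download and caching logic is omitted, and `file_matches_hash` is a hypothetical helper name.

```python
import hashlib
from pathlib import Path

# Hypothetical sketch of the hash check: stream the file through the
# digest in chunks (so large checkpoints need not fit in memory) and
# compare hex digests.
def file_matches_hash(path: Path, expected: str, algo: str = "sha256") -> bool:
    h = hashlib.new(algo)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 16), b""):
            h.update(chunk)
    return h.hexdigest() == expected
```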

ml.utils.checkpoint.get_state_dict_prefix(ckpt: dict[str, T], prefix: str | None = None, suffix: str | None = None, regexp: Pattern[str] | None = None) → dict[str, T]

Returns the parts of a checkpoint whose keys begin with a given prefix, end with a given suffix, or match a regular expression.

Parameters:
  • ckpt – The checkpoint to modify

  • prefix – The prefix to clip

  • suffix – The suffix to clip

  • regexp – The regexp to search for (doesn’t modify any keys)

Returns:

The modified checkpoint
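A simplified sketch of the described filtering, useful for loading a submodule's weights out of a full training checkpoint. This is an assumed re-implementation; the real function's exact semantics may differ.

```python
from typing import Dict, Optional, Pattern, TypeVar

T = TypeVar("T")

# Assumed behavior (hypothetical sketch): keep keys starting with
# `prefix` (stripping it), ending with `suffix` (stripping it), and
# matching `regexp` (which, as documented, does not modify any keys).
def get_state_dict_prefix(
    ckpt: Dict[str, T],
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    regexp: Optional[Pattern[str]] = None,
) -> Dict[str, T]:
    out: Dict[str, T] = {}
    for key, value in ckpt.items():
        if prefix is not None:
            if not key.startswith(prefix):
                continue
            key = key[len(prefix):]
        if suffix is not None:
            if not key.endswith(suffix):
                continue
            key = key[: len(key) - len(suffix)]
        if regexp is not None and regexp.match(key) is None:
            continue
        out[key] = value
    return out

ckpt = {"model.encoder.w": 1, "model.decoder.w": 2, "optim.state": 3}
print(get_state_dict_prefix(ckpt, prefix="model."))
# {'encoder.w': 1, 'decoder.w': 2}
```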