ml.utils.checkpoint
Defines checkpoint utility functions.
These functions can be used to load a model from an arbitrary config file and checkpoint. Note that a checkpoint may fail to load correctly if it is moved to a different location after being saved.
- ml.utils.checkpoint.is_missing(cfg: Any, key: str) bool [source]
Utility function for checking if a config key is missing.
This is for cases when you are using a raw dataclass rather than an OmegaConf container but want to treat them the same way.
- Parameters:
cfg – The config to check
key – The key to check
- Returns:
Whether or not the key is missing a value in the config
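As a rough illustration of the idea (a hypothetical re-implementation, not the library's actual code): OmegaConf marks unset mandatory values with the `"???"` sentinel, while a raw dataclass may carry that same sentinel as a default, and a check like `is_missing` treats both uniformly.

```python
from dataclasses import dataclass

# OmegaConf marks missing mandatory values with the "???" sentinel string.
MISSING = "???"

@dataclass
class TrainConfig:
    batch_size: int = MISSING  # type: ignore[assignment]
    lr: float = 1e-3

def is_missing(cfg: object, key: str) -> bool:
    """Hypothetical sketch: a key is missing if the attribute is absent
    or still holds the OmegaConf "???" sentinel."""
    if not hasattr(cfg, key):
        return True
    return getattr(cfg, key) == MISSING

cfg = TrainConfig()
print(is_missing(cfg, "batch_size"))  # True - still the sentinel
print(is_missing(cfg, "lr"))          # False - has a concrete default
```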
- ml.utils.checkpoint.instantiate_config(config: str | Path | DictConfig | dict) Objects [source]
Builds the objects from the raw config.
- Parameters:
config – The config to use. If a string or a Path, it is expected to be a path to a YAML file.
- Returns:
The instantiated objects.
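The function accepts either an in-memory config or a path to a YAML file. A minimal sketch of that input normalization (hypothetical; the real implementation uses OmegaConf containers and the library's object registry):

```python
from pathlib import Path

def load_config(config) -> dict:
    """Hypothetical sketch: normalize the accepted input types to a dict.

    Strings and Paths are treated as YAML files on disk; dicts (and, in
    the real library, DictConfig containers) are used as-is.
    """
    if isinstance(config, (str, Path)):
        import yaml  # assumes PyYAML is available
        with open(config) as f:
            return yaml.safe_load(f)
    if isinstance(config, dict):
        return config
    raise TypeError(f"Unsupported config type: {type(config).__name__}")

print(load_config({"model": {"name": "mlp"}}))  # → {'model': {'name': 'mlp'}}
```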
- ml.utils.checkpoint.get_checkpoint_path(trainer: BaseTrainer, config_path: str | Path, ckpt_path: str | Path | None) Path [source]
- ml.utils.checkpoint.load_model_and_task(config_path: str | Path | None = None, ckpt_path: str | Path | None = None, to_device: bool = True, missing_ckpt_okay: bool = False) tuple[ml.models.base.BaseModel, ml.tasks.base.BaseTask] [source]
Loads a trained model and task from a config file and an optional checkpoint path.
- Parameters:
config_path – The path to the config file.
ckpt_path – The path to the checkpoint file; if None, the latest checkpoint will be used. This defaults to first checking an adjacent checkpoints directory for a ckpt.pt file, or else checking for the checkpoint file in the same directory as the config.
to_device – Whether to move the model to the device specified in the config.
missing_ckpt_okay – Whether to return a model and task even if the checkpoint is missing.
- Returns:
The model and task loaded from the checkpoint
- Raises:
ValueError – If both config_path and ckpt_path are None.
RuntimeError – If the checkpoint is missing and missing_ckpt_okay is False.
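The default checkpoint lookup described above (an adjacent checkpoints directory first, then the config's own directory) can be sketched as follows; the helper name is hypothetical and the ckpt.pt file name is taken from the description:

```python
from __future__ import annotations

from pathlib import Path

def resolve_ckpt_path(config_path: Path, ckpt_name: str = "ckpt.pt") -> Path | None:
    """Hypothetical sketch of the documented lookup order: first an
    adjacent `checkpoints` directory, then the config's own directory."""
    candidates = [
        config_path.parent / "checkpoints" / ckpt_name,
        config_path.parent / ckpt_name,
    ]
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None
```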
- ml.utils.checkpoint.ensure_downloaded(url: str, *dnames: str, md5: str | None = None, sha256: str | None = None, is_tmp: bool = False, recheck_hash: bool = False) Path [source]
Ensures that a checkpoint URL has been downloaded.
This basically just provides a nice way of organizing pre-trained models, by saving them to a consistent location.
- Parameters:
url – The URL to download.
dnames – The path components of the download location (note that this is relative to the model directory); the final component should be the file name.
md5 – The MD5 hash of the file, if known.
sha256 – The SHA256 hash of the file, if known.
is_tmp – If set, use tmp/ instead of get_model_dir().
recheck_hash – Whether to recheck the hash of the file if it already exists.
- Returns:
The path to the downloaded file.
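The hash-verification step can be sketched as below: an existing file is reused only if it matches every provided digest, otherwise it would be re-downloaded. The helper name is hypothetical and operates on a local file rather than a URL:

```python
from __future__ import annotations

import hashlib
from pathlib import Path

def file_matches_hash(path: Path, md5: str | None = None, sha256: str | None = None) -> bool:
    """Hypothetical sketch of the verification step: returns True only
    if the file exists and matches every digest that was provided."""
    if not path.exists():
        return False
    data = path.read_bytes()
    if md5 is not None and hashlib.md5(data).hexdigest() != md5:
        return False
    if sha256 is not None and hashlib.sha256(data).hexdigest() != sha256:
        return False
    return True
```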
- ml.utils.checkpoint.get_state_dict_prefix(ckpt: dict[str, T], prefix: str | None = None, suffix: str | None = None, regexp: Pattern[str] | None = None) dict[str, T] [source]
Returns the parts of a checkpoint whose keys begin with a given prefix.
- Parameters:
ckpt – The checkpoint to modify
prefix – The prefix to clip
suffix – The suffix to clip
regexp – The regexp to search for (doesn’t modify any keys)
- Returns:
The modified checkpoint
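A minimal sketch of this filtering (a hypothetical re-implementation for illustration): keep only matching keys, clipping the prefix and suffix so a sub-module's weights can be loaded on their own.

```python
from __future__ import annotations

import re
from typing import TypeVar

T = TypeVar("T")

def get_state_dict_prefix(
    ckpt: dict[str, T],
    prefix: str | None = None,
    suffix: str | None = None,
    regexp: re.Pattern[str] | None = None,
) -> dict[str, T]:
    """Hypothetical sketch: filter keys by prefix, suffix, or regexp,
    clipping the prefix and suffix from the surviving keys."""
    if prefix is not None:
        ckpt = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
    if suffix is not None:
        ckpt = {k[: -len(suffix)]: v for k, v in ckpt.items() if k.endswith(suffix)}
    if regexp is not None:
        ckpt = {k: v for k, v in ckpt.items() if regexp.search(k)}  # keys unchanged
    return ckpt

sd = {"encoder.weight": 1, "encoder.bias": 2, "decoder.weight": 3}
print(get_state_dict_prefix(sd, prefix="encoder."))  # → {'weight': 1, 'bias': 2}
```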