Source code for ml.core.config

"""Defines the base config and object classes.

The base configuration dataclass adds a few small conveniences on top of what
OmegaConf expects. The :func:`conf_field` function exists to standardize
the metadata that each field should have.
"""

import copy
from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar, cast

from omegaconf import MISSING, DictConfig, OmegaConf

FieldType = Any


[docs]def conf_field( value: FieldType, *, help: str | None = None, # pylint: disable=redefined-builtin short: str | None = None, ) -> FieldType: """Gets a field for a given value. Args: value: The default value for the current field help: An optional metadata field, which may be parsed to a command line argument in some CLIs short: An optional metadata field, which may be parsed to a command line argument in some CLIs Returns: The correctly constructed dataclass field """ metadata: dict[str, Any] = {} if help is not None: metadata["help"] = help if short is not None: metadata["short"] = short if hasattr(value, "__call__"): return field(default_factory=value, metadata=metadata) if value.__class__.__hash__ is None: return field(default_factory=lambda: value, metadata=metadata) return field(default=value, metadata=metadata)
BaseConfigT = TypeVar("BaseConfigT", bound="BaseConfig")


@dataclass
class BaseConfig:
    """Defines the base class for all configs.

    The only field declared here is ``name``, which references the object
    to construct; dataclass subclasses inherit it.
    """

    name: str = conf_field(
        MISSING,
        short="n",
        help="The referenced name of the object to construct",
    )

    @classmethod
    def get_defaults(cls: type[BaseConfigT]) -> dict[str, BaseConfigT]:
        """Provides named default configurations for this config class.

        The base implementation defines none; subclasses may override it.

        Returns:
            A dictionary of default configurations (empty for the base class)
        """
        defaults: dict[str, BaseConfigT] = dict()
        return defaults

    @classmethod
    def update(cls: type[BaseConfigT], config: DictConfig) -> DictConfig:
        """Runs post-construction config update.

        Merges ``config`` on top of the structured schema built from ``cls``.

        Args:
            config: The config to update

        Returns:
            The merged config
        """
        schema = OmegaConf.structured(cls)
        merged = OmegaConf.merge(schema, config)
        return cast(DictConfig, merged)

    @classmethod
    def resolve(cls: type[BaseConfigT], config: BaseConfigT) -> None:
        """Runs post-construction config resolution.

        The base implementation is a no-op; subclasses may override it.

        Args:
            config: The config to resolve
        """
        return None
class BaseObject(Generic[BaseConfigT]):
    """Defines the base class for all objects.

    Stores the typed ``config`` given at construction, plus a raw
    ``DictConfig`` that can be assigned at most once via ``set_raw_config``.
    """

    # NOTE(review): presumably consumed by TorchScript, which treats names
    # listed in ``__constants__`` as constants; confirm against subclasses.
    __constants__ = ["config"]

    def __init__(self, config: BaseConfigT) -> None:
        self.config: BaseConfigT = config
        # Left unset until ``set_raw_config`` is called.
        self._raw_config: DictConfig | None = None

    @property
    def raw_config(self) -> DictConfig:
        """The raw config previously stored with ``set_raw_config``.

        Raises:
            RuntimeError: If no raw config has been assigned yet
        """
        raw = self._raw_config
        if raw is None:
            raise RuntimeError("Cannot access raw config yet; it has yet to be assigned")
        return raw

    def set_raw_config(self, raw_config: DictConfig) -> None:
        """Stores the raw config; only one assignment is allowed.

        Args:
            raw_config: The raw config to store

        Raises:
            RuntimeError: If a raw config was already stored
        """
        already_written = self._raw_config is not None
        if already_written:
            raise RuntimeError("The raw config object was already written")
        self._raw_config = raw_config