# Source code for ml.core.env

"""Defines any core environment variables used in the ML repository.

In order to keep all environment variables in one place, so that they can be
easily referenced, don't use `os.environ` or `os.getenv` outside of this file.
Instead, add a new accessor function to this file.
"""

import os
from pathlib import Path


[docs]class StrEnvVar: def __init__(self, key: str, *, default: str | None = None) -> None: self.key = key self.default = default
[docs] def get(self, *, allow_default: bool = True) -> str: value = self.maybe_get(allow_default=allow_default) if value is None: raise KeyError(f"Value for {self.key} environment variable is not set") return value
[docs] def maybe_get(self, *, allow_default: bool = True) -> str | None: if self.key in os.environ: return os.environ[self.key] if allow_default: return self.default return None
[docs] def set(self, value: str) -> None: os.environ[self.key] = value
class StrSetEnvVar:
    """Accessor for an environment variable holding a set of strings.

    The set is stored in the environment as a single string whose entries
    are joined by ``sep``.

    Args:
        key: Name of the environment variable.
        sep: Separator placed between entries in the stored string.
    """

    def __init__(self, key: str, *, sep: str = ",") -> None:
        self.key = key
        self.sep = sep

    def get(self) -> set[str]:
        """Return the stored entries as a set.

        Returns:
            The non-empty entries of the stored string; an empty set when
            the variable is unset or empty.
        """
        raw = os.environ.get(self.key, "")
        # Split on the separator (not the variable name) so that values
        # written by `set` round-trip; empty fragments are dropped.
        return {item for item in raw.split(self.sep) if item}

    def set(self, values: set[str]) -> None:
        """Store ``values`` in the environment.

        Entries are sorted before joining so the stored string is
        deterministic; empty strings are skipped.
        """
        os.environ[self.key] = self.sep.join(v for v in sorted(values) if v)

    def add(self, value: str) -> None:
        """Add a single entry to the stored set."""
        self.set(self.get() | {value})
class BoolEnvVar:
    """Accessor for a boolean environment variable stored as an integer.

    Args:
        key: Name of the environment variable.
        default: Value reported when the variable is not set.
    """

    def __init__(self, key: str, default: bool = False) -> None:
        self.key = key
        self.default = default

    def get(self, *, allow_default: bool = True) -> bool:
        """Return the flag's value.

        The stored string is parsed with ``int``; any non-zero integer is
        treated as ``True``.

        Args:
            allow_default: Whether the configured default may be returned
                when the variable is not present in the environment.

        Returns:
            The parsed flag, or the default.

        Raises:
            KeyError: If the variable is unset and defaults are disallowed.
            ValueError: If the stored string is not an integer literal.
        """
        raw = os.environ.get(self.key)
        if raw is not None:
            return int(raw) != 0
        if not allow_default:
            raise KeyError(f"Value for {self.key} environment variable is not set")
        return self.default

    def set(self, val: bool) -> None:
        """Store the flag as ``"1"`` (truthy) or ``"0"`` (falsy)."""
        if val:
            encoded = "1"
        else:
            encoded = "0"
        os.environ[self.key] = encoded
[docs]class IntEnvVar: def __init__(self, key: str, *, default: int | None = None) -> None: self.key = key self.default = default
[docs] def get(self, *, allow_default: bool = True) -> int: value = self.maybe_get(allow_default=allow_default) if value is None: raise KeyError(f"Value for {self.key} environment variable is not set") return value
[docs] def maybe_get(self, *, allow_default: bool = True) -> int | None: if self.key in os.environ: return int(os.environ[self.key]) if allow_default: return self.default return None
[docs] def set(self, value: int) -> None: os.environ[self.key] = str(value)
[docs]class PathEnvVar: def __init__(self, key: str, *, default: Path | None = None) -> None: self.key = key self.default = default
[docs] def get(self, *, allow_default: bool = True) -> Path: value = self.maybe_get(allow_default=allow_default) if value is None: raise KeyError(f"Value for {self.key} environment variable is not set") return value
[docs] def maybe_get(self, *, allow_default: bool = True) -> Path | None: if self.key in os.environ: return Path(os.environ[self.key]).resolve() if allow_default: return self.default return None
[docs] def set(self, value: Path) -> None: os.environ[self.key] = str(value.resolve())
# Module-level environment variable declarations. Each variable is declared
# once and its bound methods are re-exported under function-style names.
# Defaults built from `Path.home()` / `Path.cwd()` are evaluated when this
# module is imported, not when the accessor is called.

# Option to toggle debug mode (turns off dataloader multiprocessing, improves logging).
Debugging = BoolEnvVar("DEBUG")
is_debugging = Debugging.get

# Where to store miscellaneous cache artifacts.
CacheDir = PathEnvVar("CACHE_DIR", default=Path.home() / ".cache" / "ml-starter" / "model-artifacts")
get_cache_dir = CacheDir.get

# Root directory for training runs.
RunDir = PathEnvVar("RUN_DIR", default=Path.cwd() / "runs")
get_run_dir = RunDir.get
set_run_dir = RunDir.set

# Root directory for evaluation runs.
EvalRunDir = PathEnvVar("EVAL_RUN_DIR", default=Path.cwd() / "evals")
get_eval_run_dir = EvalRunDir.get
set_eval_run_dir = EvalRunDir.set

# The name of the experiment (set by the training script).
ExpName = StrEnvVar("EXPERIMENT_NAME", default="Experiment")
get_exp_name = ExpName.get
set_exp_name = ExpName.set

# Base directory where various datasets are stored.
DataDir = PathEnvVar("DATA_DIR", default=Path.home() / ".cache" / "ml-starter" / "datasets")
get_data_dir = DataDir.get
set_data_dir = DataDir.set

# Slurm configuration file path.
SlurmConfPath = PathEnvVar("ML_SLURM_CONF", default=Path.home() / ".slurm.yaml")
get_slurm_conf_path = SlurmConfPath.get

# S3 bucket where various datasets are stored.
# No default is configured, so the getter raises KeyError when unset.
S3DataBucket = StrEnvVar("S3_DATA_BUCKET")
get_s3_data_bucket = S3DataBucket.get
set_s3_data_bucket = S3DataBucket.set

# S3 bucket where runs are stored.
# No default is configured, so the getter raises KeyError when unset.
S3RunsBucket = StrEnvVar("S3_RUNS_BUCKET")
get_s3_runs_bucket = S3RunsBucket.get
set_s3_runs_bucket = S3RunsBucket.set

# Base directory where various pretrained models are stored.
ModelDir = PathEnvVar("MODEL_DIR", default=Path.home() / ".cache" / "ml-starter" / "models")
get_model_dir = ModelDir.get
set_model_dir = ModelDir.set

# The global random seed.
RandomSeed = IntEnvVar("RANDOM_SEED", default=1337)
get_env_random_seed = RandomSeed.get
set_env_random_seed = RandomSeed.set

# Directory where code is staged before running large-scale experiments.
StageDir = PathEnvVar("STAGE_DIR", default=Path.home() / ".cache" / "ml-starter" / "staging")
get_stage_dir = StageDir.get
set_stage_dir = StageDir.set

# Global experiment tags (used for the experiment name, among other things).
GlobalTags = StrSetEnvVar("GLOBAL_MODEL_TAGS")
get_global_tags = GlobalTags.get
set_global_tags = GlobalTags.set
add_global_tag = GlobalTags.add

# Disables using accelerator on Mac.
DisableMetal = BoolEnvVar("DISABLE_METAL", default=False)
is_metal_disabled = DisableMetal.get

# Disables using the GPU.
DisableGPU = BoolEnvVar("DISABLE_GPU", default=False)
is_gpu_disabled = DisableGPU.get

# Disables colors in various parts.
DisableColors = BoolEnvVar("DISABLE_COLORS", default=False)
are_colors_disabled = DisableColors.get

# Disables Tensorboard subprocess.
DisableTensorboard = BoolEnvVar("DISABLE_TENSORBOARD", default=False)
is_tensorboard_disabled = DisableTensorboard.get

# Show full error message when trying to import a file.
ShowFullImportError = BoolEnvVar("SHOW_FULL_IMPORT_ERROR", default=False)
should_show_full_import_error = ShowFullImportError.get

# The path to the resolved config.
# The getter is `maybe_get`, so it returns None (rather than raising) when unset.
MLConfigPath = PathEnvVar("ML_CONFIG_PATH")
get_ml_config_path = MLConfigPath.maybe_get
set_ml_config_path = MLConfigPath.set

# Show all logs for matplotlib, PIL, torch, etc.
ShowAllLogs = BoolEnvVar("SHOW_ALL_LOGS", default=False)
should_show_all_logs = ShowAllLogs.get

# Ignore the cache file when looking for modules.
IgnoreRegistryCache = BoolEnvVar("IGNORE_REGISTRY_CACHE", default=False)
ignore_registry_cache = IgnoreRegistryCache.get

# The Weights & Biases entity.
# The getter is `maybe_get`, so it returns None (rather than raising) when unset.
WandbEntity = StrEnvVar("WANDB_ENTITY")
get_wandb_entity = WandbEntity.maybe_get

# Path to the default config files.
DefaultConfigRootPath = PathEnvVar("DEFAULT_CONFIG_PATH", default=Path.home() / ".config" / "ml")
get_default_config_root_path = DefaultConfigRootPath.get