Source code for ml.utils.argparse

"""Utilities for using argparse with dataclasses."""

import argparse
from dataclasses import MISSING, fields
from typing import Any, Type, TypeVar, Union, cast, get_args, get_origin

from omegaconf import OmegaConf

from ml.core.config import BaseConfig

Config = TypeVar("Config", bound=BaseConfig)


def get_type_from_string(type_name: str) -> Type:
    """Maps the name of a primitive type to the type itself.

    Args:
        type_name: The name to look up; one of ``"str"``, ``"float"`` or
            ``"int"``.

    Returns:
        The builtin type that the name refers to.

    Raises:
        ValueError: If the name is not one of the supported type names.
    """
    supported_types = (("str", str), ("float", float), ("int", int))
    for known_name, builtin_type in supported_types:
        if type_name == known_name:
            return builtin_type
    raise ValueError(type_name)
def add_args(parser: argparse.ArgumentParser, dc: Type[Config]) -> None:
    """Adds arguments to an argument parser from a dataclass.

    Every field of the dataclass becomes one option named ``--field-name``
    (underscores in the field name are replaced with dashes). The field
    metadata may provide ``short`` (added as the short flag ``-<short>``) and
    ``help`` (used as the help text of the option).

    Supported field types:

    * ``str``, ``float`` and ``int``: the field must have a default.
    * ``bool``: the field must have a default; the option is a flag which
      stores the opposite of that default when it is passed.
    * The strings ``"str"``, ``"float"`` and ``"int"``: the option is marked
      as required when the field has no default, or when its default is not
      an instance of the named type (that default is then dropped).
    * ``Optional[T]`` / ``Union[T, None]`` / ``T | None``, where ``T`` is
      ``str``, ``float`` or ``int``.

    Args:
        parser: The argument parser to add arguments to.
        dc: The dataclass to add arguments from.

    Raises:
        AssertionError: If a ``str``, ``float``, ``int`` or ``bool`` field
            has neither a default nor a default factory.
        NotImplementedError: If the dataclass has a field with an unsupported
            type.
    """
    import types  # Local import: only needed to recognise ``X | Y`` unions.

    # ``X | None`` reports ``types.UnionType`` as its origin on Python 3.10+;
    # interpreters without that attribute only ever produce ``typing.Union``.
    union_origins = (Union, getattr(types, "UnionType", Union))

    for field in fields(dc):
        metadata = field.metadata

        flags: list[str] = []
        short = metadata.get("short")
        if short is not None:
            flags.append(f"-{short}")
        flags.append(f"--{field.name.replace('_', '-')}")

        kwargs: dict[str, Any] = {}

        # ``MISSING`` is a sentinel, so it is compared by identity; an
        # equality test could be hijacked by a default that overloads ``!=``.
        if field.default is not MISSING:
            kwargs["default"] = field.default
        elif field.default_factory is not MISSING:
            kwargs["default"] = field.default_factory()

        help_text = metadata.get("help")
        if help_text is not None:
            kwargs["help"] = help_text

        if field.type in (str, float, int):
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped under ``python -O``; the exception type is kept.
            if "default" not in kwargs:
                raise AssertionError(f"Field {field.name} requires default")
            kwargs["type"] = field.type

        elif field.type is bool:
            if "default" not in kwargs:
                raise AssertionError(f"Field {field.name} requires default")
            default_val = cast(bool, kwargs["default"])
            # Passing the flag flips the default value.
            kwargs["action"] = "store_false" if default_val else "store_true"

        elif field.type in ("str", "float", "int"):  # type: ignore[comparison-overlap]
            kwargs["type"] = get_type_from_string(cast(str, field.type))
            if "default" in kwargs and not isinstance(kwargs["default"], kwargs["type"]):
                # The default is not a usable value of the declared type, so
                # the value has to be supplied on the command line instead.
                kwargs.pop("default")
                kwargs["required"] = True
            elif "default" not in kwargs:
                kwargs["required"] = True

        elif get_origin(field.type) in union_origins:
            field_types = set(get_args(field.type))
            field_types.discard(type(None))
            if len(field_types) != 1:
                raise NotImplementedError(f"Field {field.name} has multiple types: {field_types}")
            (field_type,) = field_types
            if field_type not in (str, float, int):
                raise NotImplementedError(f"Field {field.name} has unsupported type: {field_type}")
            kwargs["type"] = field_type

        else:
            raise NotImplementedError(f"Couldn't get type for {field.name}")

        parser.add_argument(*flags, **kwargs)
def from_args(args: argparse.Namespace, dc: Type[Config]) -> Config:
    """Builds a config from the values in a parsed argument namespace.

    Args:
        args: The parsed namespace; it must carry one attribute per field of
            the dataclass, named after that field.
        dc: The dataclass type to instantiate.

    Returns:
        The structured OmegaConf config created from the populated dataclass,
        after it has been passed to ``dc.resolve``.
    """
    field_values: dict[str, Any] = {f.name: getattr(args, f.name) for f in fields(dc)}
    config = OmegaConf.structured(dc(**field_values))
    dc.resolve(config)
    return config