Source code for ml.tasks.datasets.error_handling

"""Defines error handling wrappers for datasets.

The worst feeling in the world is when you're training a model and it crashes
after 10 hours of training. This module defines some error handling wrappers
for datasets which will catch errors and log them (in batches).
"""

import bdb
import logging
import random
import sys
import time
from collections import Counter
from dataclasses import dataclass
from typing import Iterator, TypeVar, no_type_check

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
from torch.utils.data.dataset import Dataset, IterableDataset

from ml.core.config import conf_field
from ml.utils.colors import colorize
from ml.utils.data import get_worker_info

logger: logging.Logger = logging.getLogger(__name__)

BatchT = TypeVar("BatchT")
DatasetT = TypeVar("DatasetT", Dataset, IterableDataset, MapDataPipe, IterDataPipe)


[docs]def get_loc(num_excs: int = 1) -> str: _, _, exc_tb = sys.exc_info() if exc_tb is None or (exc_tb := exc_tb.tb_next) is None: return "unknown" exc_strs: list[str] = [] for _ in range(num_excs): exc_strs += [f"{exc_tb.tb_frame.f_code.co_filename}:{exc_tb.tb_lineno}"] if (exc_tb := exc_tb.tb_next) is None: break return "\n".join(exc_strs)
[docs]@dataclass class ErrorHandlingConfig: enabled: bool = conf_field(True, help="Is error handling enabled?") maximum_exceptions: int = conf_field(10, help="Maximum number of errors to encounter") backoff_after: int = conf_field(5, help="Start to do a sleeping backoff after this many exceptions") sleep_backoff: float = conf_field(0.1, help="Sleep backoff amount") sleep_backoff_power: float = conf_field(2.0, help="How much to multiply backoff for each successive exception") log_full_exception: bool = conf_field(False, help="Log the full exception message for each exception") flush_exception_summary_every: int = conf_field(500, help="How often to flush exception summary") report_top_n_exception_types: int = conf_field(5, help="Number of exceptions to summarize") exception_location_traceback_depth: int = conf_field(3, help="Traceback length for the exception location")
[docs]class ExceptionSummary: def __init__(self, config: ErrorHandlingConfig) -> None: self.steps = 0 self.step_has_error = False self.total_exceptions = 0 self.flush_every = config.flush_exception_summary_every self.summary_length = config.report_top_n_exception_types self.exceptions: Counter[str] = Counter() self.exception_classes: Counter[str] = Counter() self.exception_locs: Counter[str] = Counter() self.last_exception: Exception | None = None
[docs] def add_exception(self, exc: Exception, loc: str) -> None: self.last_exception = exc self.exceptions[f"{exc.__class__.__name__}: {exc}"] += 1 self.exception_classes[exc.__class__.__name__] += 1 self.exception_locs[loc] += 1 if not self.step_has_error: self.total_exceptions += 1 self.step_has_error = True
[docs] def step(self) -> None: if self.steps >= self.flush_every: self.flush() self.steps += 1 self.step_has_error = False
[docs] def summary(self) -> str: lines: list[str] = [] def get_segment_header(header: str) -> list[str]: header = colorize(f"{header:60s}", "yellow", bold=True) count = colorize(f"{'Count':10s}", "yellow", bold=False) percent = colorize(f"{'Percent':10s}", "yellow", bold=False) return [f"│ {header}{count}{percent} │"] def get_log_line(ks: str, v: int, as_red: bool = False) -> str: chunks = [k[i : i + 60] for k in ks.split("\n") for i in range(0, len(k), 60)] v_int, v_prct = f"{v}", f"{int(v * 100 / self.steps)} %" c = colorize(f"{chunks[0]:60s}", "red", bold=True) if as_red else f"{chunks[0]:60s}" log_lines = [f"│ {c}{v_int:10s}{v_prct:10s} │"] for chunk in chunks[1:]: c = colorize(f"{chunk:60s}", "red", bold=True) if as_red else f"{chunk:60s}" log_lines += [f"│ {c}{'':10s}{'':10s} │"] return "\n".join(log_lines) def get_single_log_line(ks: str) -> str: chunks = [k[i : i + 80] for k in ks.split("\n") for i in range(0, len(k), 82)] c = colorize(f"{chunks[0]:86s}", "red", bold=True) log_lines = [f"│ {c} │"] for chunk in chunks[1:]: c = colorize(f"{chunk:86s}", "red", bold=True) log_lines += [f"│ {c} │"] return "\n".join(log_lines) def get_line_break() -> str: return f"├─{'─' * 60}─┼─{'─' * 10}─┼─{'─' * 10}─┤" def get_line_start() -> str: return f"┌─{'─' * 60}─┬─{'─' * 10}─┬─{'─' * 10}─┐" def get_line_break_before_single() -> str: return f"├─{'─' * 60}─┴─{'─' * 10}─┴─{'─' * 10}─┤" def get_line_end() -> str: return f"└─{'─' * 60}───{'─' * 10}───{'─' * 10}─┘" # Logs the unique exception strings. lines += [get_line_start()] lines += get_segment_header("Error Messages") for k, v in self.exceptions.most_common(self.summary_length): lines += [get_log_line(k, v)] # Logs the individual exception classes. lines += [get_line_break()] lines += get_segment_header("Error Types") for k, v in self.exception_classes.most_common(self.summary_length): lines += [get_log_line(k, v)] # Logs by line number. lines += [get_line_break()] lines += get_segment_header("Error Locations") for k, v in self.exception_locs.most_common(self.summary_length): lines += [get_log_line(k, v)] # Logs the total number of exceptions. error_line = ( f"Error Rate: {self.total_exceptions} failed / {self.steps} total " f"({self.total_exceptions / self.steps * 100:.2f} %)" ) lines += [get_line_break_before_single()] lines += [get_single_log_line(error_line)] lines += [get_line_end()] return "\n".join(lines)
[docs] def flush(self) -> None: worker_info = get_worker_info() if worker_info.worker_id == 0 and self.total_exceptions > 0: logger.info("Exception summary:\n\n%s\n", self.summary()) self.exceptions.clear() self.exception_classes.clear() self.exception_locs.clear() self.steps = 0 self.total_exceptions = 0
[docs]class ErrorHandlingDataset(Dataset[BatchT]): """Defines a wrapper for safely handling errors.""" def __init__(self, dataset: Dataset[BatchT], config: ErrorHandlingConfig) -> None: super().__init__() self.dataset = dataset self.config = config self.exc_summary = ExceptionSummary(config) def __getitem__(self, index: int) -> BatchT: num_exceptions = 0 backoff_time = self.config.sleep_backoff self.exc_summary.step() while num_exceptions < self.config.maximum_exceptions: try: return self.dataset[index] except bdb.BdbQuit as e: logger.info("User interrupted debugging session; aborting") raise e except Exception as e: if self.config.log_full_exception: logger.exception("Caught exception on index %d", index) self.exc_summary.add_exception(e, get_loc(self.config.exception_location_traceback_depth)) index = random.randint(0, len(self) - 1) num_exceptions += 1 if num_exceptions > self.config.backoff_after: logger.error( "Encountered %d exceptions for a single index, backing off for %f seconds", num_exceptions, backoff_time, ) time.sleep(backoff_time) backoff_time *= self.config.sleep_backoff_power exc_message = f"Reached max exceptions {self.config.maximum_exceptions}\n{self.exc_summary.summary()}" if self.exc_summary.last_exception is None: raise RuntimeError(exc_message) raise RuntimeError(exc_message) from self.exc_summary.last_exception def __len__(self) -> int: if hasattr(self.dataset, "__len__"): return self.dataset.__len__() raise NotImplementedError("Base dataset doesn't implemenet `__len__`")
[docs]@functional_datapipe("map_error_handling") class ErrorHandlingMapDataPipe(ErrorHandlingDataset[BatchT], MapDataPipe[BatchT]): """Defines a wrapper for safely handling errors.""" def __init__(self, datapipe: MapDataPipe[BatchT], config: ErrorHandlingConfig) -> None: ErrorHandlingDataset.__init__(self, datapipe, config) MapDataPipe.__init__(self)
[docs]class ErrorHandlingIterableDataset(IterableDataset[BatchT]): """Defines a wrapper for safely handling errors in iterable datasets.""" def __init__(self, dataset: IterableDataset[BatchT], config: ErrorHandlingConfig) -> None: super().__init__() self.iteration = 0 self.dataset = dataset self.config = config self.exc_summary = ExceptionSummary(config) self.iter: Iterator[BatchT] | None = None self._configured_logging = False def __iter__(self) -> Iterator[BatchT]: self.iter = self.dataset.__iter__() self.iteration = 0 return self def __next__(self) -> BatchT: assert self.iter is not None, "Must call `__iter__` before `__next__`" num_exceptions = 0 backoff_time = self.config.sleep_backoff self.exc_summary.step() self.iteration += 1 while num_exceptions < self.config.maximum_exceptions: try: return self.iter.__next__() except bdb.BdbQuit as e: logger.info("User interrupted debugging session; aborting") raise e except StopIteration as e: raise e except Exception as e: if self.config.log_full_exception: logger.exception("Caught exception on iteration %d", self.iteration) self.exc_summary.add_exception(e, get_loc(self.config.exception_location_traceback_depth)) num_exceptions += 1 if num_exceptions > self.config.backoff_after: logger.error( "Encountered %d exceptions for a single index, backing off for %f seconds", num_exceptions, backoff_time, ) time.sleep(backoff_time) backoff_time *= self.config.sleep_backoff_power raise RuntimeError(f"Reached max exceptions {self.config.maximum_exceptions}\n{self.exc_summary.summary()}")
[docs]@functional_datapipe("iter_error_handling") class ErrorHandlingIterDataPipe(ErrorHandlingIterableDataset[BatchT], IterDataPipe[BatchT]): """Defines a wrapper for safely handling errors in iterable datapipe.""" def __init__(self, datapipe: IterDataPipe[BatchT], config: ErrorHandlingConfig) -> None: ErrorHandlingIterableDataset.__init__(self, datapipe, config) IterDataPipe.__init__(self)
[docs]@no_type_check def error_handling_dataset(dataset: DatasetT, config: ErrorHandlingConfig) -> DatasetT: """Returns a dataset which wraps the base dataset and handles errors. Args: dataset: The dataset to handle errors for config: An associated config, describing which errors to handle Returns: The wrapped dataset, which catches some errors Raises: NotImplementedError: If the dataset type is not supported """ if isinstance(dataset, MapDataPipe): return ErrorHandlingMapDataPipe(dataset, config) if isinstance(dataset, IterDataPipe): return ErrorHandlingIterDataPipe(dataset, config) if isinstance(dataset, IterableDataset): return ErrorHandlingIterableDataset(dataset, config) elif isinstance(dataset, Dataset): return ErrorHandlingDataset(dataset, config) raise NotImplementedError(f"Unexpected type: {dataset}")
[docs]def test_exception_summary() -> None: summary = ExceptionSummary(ErrorHandlingConfig()) for i in range(10): try: if i < 7: raise RuntimeError("test") else: raise ValueError("test 2") except Exception as e: summary.add_exception(e, get_loc()) summary.step() print(summary.summary())
if __name__ == "__main__": # python -m ml.tasks.datasets.error_handling test_exception_summary()