Source code for ml.utils.tokens

"""Defines utility functions for dealing with tokens and token datasets.

This file provides helper methods for reading and writing compressed datasets
of tokens. This compresses the tokens into ``ceil(log2(num_tokens))`` bits per
token, with padding at the end of each line to ensure that each line is a
multiple of 8 bits. This optimizes for making the file size as small as
possible while still being efficient to read from.

Here's an example of how to use the API:

.. highlight:: python
.. code-block:: python

    from ml.utils.tokens import TokenReader, TokenWriter

    num_tokens = 6
    file_path = "/path/to/dataset.bin"

    # Write the tokens to the dataset.
    with TokenWriter(file_path, num_tokens) as writer:
        for _ in range(10):
            writer.write([1, 2, 3, 4, 5])

    # Read the tokens from the dataset.
    reader = TokenReader(file_path)
    num_samples = len(reader)
    for i in range(num_samples):
        print(reader[i])

You can also read some subset of the tokens in a line using slicing syntax.
This syntax will only read the required tokens from the file, rather than
reading the entire line and then slicing it. Here is an example:

.. highlight:: python
.. code-block:: python

    reader = TokenReader(file_path)
    print(reader[0])  # Prints the first line.
    print(reader[0, 1:3])  # Prints the first line, but only the second and third tokens.
"""

import functools
import logging
import math
import struct
from pathlib import Path
from types import TracebackType
from typing import BinaryIO, ContextManager, Iterable, Literal, cast, overload

from smart_open import open

logger = logging.getLogger(__name__)

NumberFormat = Literal["Q", "I", "H", "B"]

MAGIC = b"MLTK"  # Magic number for the token file format.
OFFSET_MAGIC = b"MLTO"  # Magic number for the offsets file format.


def _arr_to_bytes(tokens: Iterable[int], num_tokens: int, offset: int = 0) -> tuple[bytes, int]:
    assert 0 <= offset < 8
    num_bits = (num_tokens - 1).bit_length()
    byte_arr = bytearray()
    cur_token = 0
    cur_bits = 0
    total_len = 0
    for token in tokens:
        total_len += 1
        assert 0 <= token <= num_tokens
        cur_token += token << cur_bits
        cur_bits += num_bits
        if offset > 0:
            cur_token <<= offset
            cur_bits += offset
            offset = 0
        while cur_bits >= 8:
            byte_arr.append(cur_token & 0xFF)
            cur_token >>= 8
            cur_bits -= 8
    if cur_bits:
        byte_arr.append(cur_token)
    return bytes(byte_arr), total_len


def _bytes_to_arr(data: bytes, seq_len: int, num_tokens: int, offset: int = 0) -> list[int]:
    assert 0 <= offset < 8
    num_bits = (num_tokens - 1).bit_length()
    arr: list[int] = []
    cur_token = 0
    cur_bits = 0
    mask = (1 << num_bits) - 1
    for byte in data:
        cur_token += byte << cur_bits
        cur_bits += 8
        if offset != 0:
            cur_token >>= offset
            cur_bits -= offset
            offset = 0
        while cur_bits >= num_bits:
            arr.append(cur_token & mask)
            if len(arr) == seq_len:
                return arr
            cur_token >>= num_bits
            cur_bits -= num_bits
    raise ValueError("Not enough bytes to fill sequence")


[docs]class TokenWriter(ContextManager): """Helper class for writing a dataset of tokens to a file. This class can be used in conjunction with :class:`TokenReader` to write and read datasets of tokens. The default numerical formats are chosen to work well with typical ranges of token datasets. At the upper end, this supports ``2 ^ 32`` tokens, ``2 ^ 32`` tokens per line, and ``2 ^ 64`` tokens per file. Parameters: path: The path to the file to write to. num_tokens: The number of tokens in the dataset. overwrite_if_exists: Whether to overwrite the file if it already exists. num_tokens_fmt: The format string for the number of tokens. lengths_fmt: The format string for the lengths of each line. offset_fmt: The format string for the offsets of each line. """ def __init__( self, path: str | Path, num_tokens: int, overwrite_if_exists: bool = False, *, num_tokens_fmt: NumberFormat = "I", lengths_fmt: NumberFormat = "I", offset_fmt: NumberFormat = "Q", ) -> None: self._path = Path(path) self._fp: BinaryIO | None = None self._offsets: list[int] = [] self._offset_idx = -1 self._num_tokens = num_tokens self._overwrite_if_exists = overwrite_if_exists self._num_tokens_fmt = num_tokens_fmt self._lengths_fmt = lengths_fmt self._offset_fmt = offset_fmt def __enter__(self) -> "TokenWriter": if self._path.exists(): if self._overwrite_if_exists: logger.warning("Token file already exists and will be overwritten") else: raise FileExistsError(f"Token file already exists at {self._path}") self._fp = cast(BinaryIO, open(self._path, "wb")) self._offsets = [] # Writes the file magic. self._fp.write(MAGIC) # Writes the number formats which were used in this file. self._fp.write((self._num_tokens_fmt + self._lengths_fmt + self._offset_fmt).encode("ascii")) # Writes the number of unique tokens. self._fp.write(struct.pack(self._num_tokens_fmt, self._num_tokens)) # Writes a pointer to the start of the offsets table and the number # of offsets (i.e., the number of written rows). self._offset_idx = self._fp.tell() self._fp.write(struct.pack(f"2{self._offset_fmt}", 0, 0)) return self def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None: assert self._fp is not None assert self._offset_idx != -1 # Writes the offsets table. offsets_start = self._fp.tell() self._fp.write(struct.pack(f"{len(self._offsets)}{self._offset_fmt}", *self._offsets)) # Writes the pointer to the offsets table. self._fp.seek(self._offset_idx) self._fp.write(struct.pack(f"2{self._offset_fmt}", offsets_start, len(self._offsets))) self._fp.flush() self._fp.close()
[docs] def write(self, tokens: Iterable[int]) -> None: assert self._fp is not None, "TokenWriter must be opened with a context manager" # Converts the tokens to a binary array. byte_data, num_tokens = _arr_to_bytes(tokens, self._num_tokens) # Writes the binary data self._offsets.append(self._fp.tell()) self._fp.write(struct.pack(self._lengths_fmt, num_tokens)) self._fp.write(byte_data)
[docs] def writemany(self, tokens: Iterable[Iterable[int]]) -> None: assert self._fp is not None, "TokenWriter must be opened with a context manager" for line in tokens: self.write(line)
[docs] def flush(self) -> None: assert self._fp is not None, "TokenWriter must be opened with a context manager" self._fp.flush()
[docs]class TokenReader: """Helper class for reading a dataset of tokens from a file. This class can be used in conjunction with :class:`TokenWriter` to write and read datasets of tokens. Parameters: path: The path to the file to read from. shard: Read a specific shard from the dataset. """ def __init__(self, path: str | Path) -> None: self._path = Path(path) with open(self._path, "rb") as f: magic = f.read(len(MAGIC)) if magic != MAGIC: raise ValueError("Invalid token file") # Reads the number formats. fmt_strings = f.read(3).decode("ascii") self._num_tokens_fmt = fmt_strings[0] self._lengths_fmt = fmt_strings[1] self._offset_fmt = fmt_strings[2] # Reads the number of tokens. self._num_tokens = struct.unpack(self._num_tokens_fmt, f.read(struct.calcsize(self._num_tokens_fmt)))[0] # Reads the offset table start and length. offsets_vals = struct.unpack(f"2{self._offset_fmt}", f.read(struct.calcsize(self._offset_fmt) * 2)) self._offset_start, self._num_rows = offsets_vals self._lengths_fmt_size = struct.calcsize(self._lengths_fmt) def read_offsets() -> list[int]: offsets: list[int] = [] f.seek(self._offset_start) offset_bytes = f.read(struct.calcsize(self._offset_fmt) * self._num_rows) offsets.extend(struct.unpack(f"{self._num_rows}{self._offset_fmt}", offset_bytes)) offsets.append(self._offset_start) return offsets self._offsets = read_offsets() @functools.cached_property def bits_per_token(self) -> int: return math.ceil(math.log2(self._num_tokens))
[docs] def byte_length(self, index: int) -> int: start = self._offsets[index] end = self._offsets[index + 1] return end - start
[docs] def length(self, index: int) -> int: return ((self.byte_length(index) - self._lengths_fmt_size) * 8) // self.bits_per_token
@property def byte_lengths(self) -> list[int]: return [self.byte_length(i) for i in range(self._num_rows)] @property def lengths(self) -> list[int]: return [self.length(i) for i in range(self._num_rows)] @property def offsets(self) -> list[int]: return self._offsets def __len__(self) -> int: return self._num_rows @overload def __getitem__(self, index: int | tuple[int, slice]) -> list[int]: ... @overload def __getitem__(self, index: slice) -> list[list[int]]: ... def __getitem__(self, index: int | tuple[int, slice] | slice) -> list[int] | list[list[int]]: if isinstance(index, int): offset = self._offsets[index] seq_len = self.length(index) start, length = offset + self._lengths_fmt_size, (seq_len * self.bits_per_token + 7) // 8 with open(self._path, "rb") as f: f.seek(start) byte_data = f.read(length) return _bytes_to_arr(byte_data, seq_len, self._num_tokens) if isinstance(index, tuple) and len(index) == 2 and isinstance(index[0], int) and isinstance(index[1], slice): index, seq_slice = index offset = self._offsets[index] seq_len = self.length(index) offset_start = offset + self._lengths_fmt_size def make_positive(n: int) -> int: return min(n if n >= 0 else n + seq_len, seq_len) # Breaks down the slice into start, stop, and step. start = 0 if seq_slice.start is None else make_positive(seq_slice.start) stop = seq_len if seq_slice.stop is None else make_positive(seq_slice.stop) if stop <= start: return cast(list[int], []) start_bit = start * self.bits_per_token start_byte, start_offset = start_bit // 8, start_bit % 8 end_byte = (stop * self.bits_per_token + 7) // 8 with open(self._path, "rb") as f: f.seek(offset_start) f.seek(start_byte, 1) byte_data = f.read(end_byte - start_byte) arr = _bytes_to_arr(byte_data, stop - start, self._num_tokens, offset=start_offset) if seq_slice.step is not None: arr = arr[:: seq_slice.step] return arr if isinstance(index, slice): def make_positive(n: int) -> int: return min(n if n >= 0 else n + len(self), len(self)) start = 0 if index.start is None else make_positive(index.start) stop = len(self) if index.stop is None else make_positive(index.stop) if stop <= start: return cast(list[int], []) # Non-contiguous reads can just be done using existing logic. if index.step is not None and index.step != 1: return [self[i] for i in range(start, stop, index.step)] offsets = [self._offsets[i] for i in range(start, stop)] seq_lens = [self.length(i) for i in range(start, stop)] start = offsets[0] + self._lengths_fmt_size stop = offsets[-1] + self._lengths_fmt_size + (seq_lens[-1] * self.bits_per_token + 7) // 8 with open(self._path, "rb") as f: f.seek(start) byte_data = f.read(stop - start) starts = [offset - offsets[0] for offset in offsets] return [ _bytes_to_arr(byte_data[start:], seq_len, self._num_tokens) for start, seq_len in zip(starts, seq_lens) ] raise TypeError("Index must be an integer or a tuple of an integer and a slice")
[docs]class token_file: # noqa: N801
[docs] @classmethod def to_bytes(cls, tokens: Iterable[int], num_tokens: int) -> bytes: return _arr_to_bytes(tokens, num_tokens)[0]
[docs] @classmethod def from_bytes(cls, tokens_enc: bytes, seq_len: int, num_tokens: int) -> list[int]: return _bytes_to_arr(tokens_enc, seq_len, num_tokens)
@overload @classmethod def open( cls, path: str | Path, mode: Literal["w"], num_tokens: int, overwrite_if_exists: bool = False, ) -> TokenWriter: ... @overload @classmethod def open(cls, path: str | Path, mode: Literal["r"] = "r") -> TokenReader: ...
[docs] @classmethod def open( cls, path: str | Path, mode: Literal["r", "w"] = "r", num_tokens: int | None = None, overwrite_if_exists: bool = False, ) -> TokenReader | TokenWriter: """Opens a token file for reading or writing. Args: path: The path to the token file. mode: The mode to open the file in. Can be either ``"r"`` for reading or ``"w"`` for writing. num_tokens: The number of tokens in the dataset. Required when opening in write mode. overwrite_if_exists: Whether to overwrite the file if it already exists. Only used when opening in write mode. Returns: A :class:`TokenReader` or :class:`TokenWriter` depending on mode. """ match mode: case "r": return TokenReader(path) case "w": if num_tokens is None: raise ValueError("`num_tokens` is required when opening in write mode") return TokenWriter( path, num_tokens=num_tokens, overwrite_if_exists=overwrite_if_exists, ) case _: raise ValueError(f"Unexpected mode: {mode}")