ml.utils.data

Some common utilities for datasets and data loaders.

class ml.utils.data.WorkerInfo(worker_id: int, num_workers: int, in_worker: bool)[source]

Bases: object

worker_id: int
num_workers: int
in_worker: bool
ml.utils.data.get_worker_info() WorkerInfo[source]

Gets a typed worker info object. Unlike torch.utils.data.get_worker_info, which returns None outside a worker process, this always returns a value; the in_worker field indicates whether the caller is running in a dataloader worker.

Returns:

The typed worker info object
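
A typical use is branching on worker identity inside a dataset. A minimal sketch (it is assumed here that in_worker is False when called from the main process rather than a dataloader worker):

    from ml.utils.data import get_worker_info

    info = get_worker_info()
    if info.in_worker:
        # Running inside a dataloader worker process.
        print(f"worker {info.worker_id} of {info.num_workers}")
    else:
        # Running in the main process.
        print("not in a dataloader worker")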

ml.utils.data.split_n_items_across_workers(n: int, worker_id: int, num_workers: int) tuple[int, int][source]

Splits N items across workers.

This returns the start and end indices for the items to be processed by the given worker. The end index is exclusive.

Parameters:
  • n – The number of items to process.

  • worker_id – The ID of the current worker.

  • num_workers – The total number of workers.
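
For example, each worker can use the returned half-open range to select its own slice of the dataset. A minimal sketch:

    from ml.utils.data import get_worker_info, split_n_items_across_workers

    items = list(range(10))
    info = get_worker_info()

    # Half-open [start, end) range of items assigned to this worker.
    start, end = split_n_items_across_workers(
        n=len(items),
        worker_id=info.worker_id,
        num_workers=info.num_workers,
    )
    my_items = items[start:end]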

ml.utils.data.get_dataset_splits(items: Sequence[T], valid: float | int, test: float | int) tuple[Sequence[T], Sequence[T], Sequence[T]][source]

Splits a list of items into three sub-lists for train, valid, and test.

Parameters:
  • items – The list of items to split.

  • valid – If a value between 0 and 1, the fraction of items to use for the validation set, otherwise the number of items to use for the validation set.

  • test – If a value between 0 and 1, the fraction of items to use for the test set, otherwise the number of items to use for the test set.

Returns:

A tuple of three sequences, one for each phase: train, valid, and test, in that order.

Raises:

ValueError – If the split sizes would be invalid (e.g., if the validation and test sets together require more items than are available).
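
Both fractional and absolute split sizes are supported. A minimal sketch:

    from ml.utils.data import get_dataset_splits

    items = list(range(100))

    # Fractional sizes: 10% validation, 10% test, the rest for training.
    train, valid, test = get_dataset_splits(items, valid=0.1, test=0.1)

    # Absolute sizes: 10 validation items and 5 test items.
    train, valid, test = get_dataset_splits(items, valid=10, test=5)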

ml.utils.data.get_dataset_split_for_phase(items: Sequence[T], phase: Literal['train', 'valid', 'test'], valid: float | int, test: float | int) Sequence[T][source]

Gets the items for a given phase.

Parameters:
  • items – The list of items to split.

  • phase – The phase to get the items for.

  • valid – If a value between 0 and 1, the fraction of items to use for the validation set, otherwise the number of items to use for the validation set.

  • test – If a value between 0 and 1, the fraction of items to use for the test set, otherwise the number of items to use for the test set.

Returns:

The items for the given phase.

Raises:

ValueError – If the phase is not one of 'train', 'valid', or 'test'.
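
This is a convenience for selecting a single split. A minimal sketch, assuming the valid and test arguments behave as in get_dataset_splits:

    from ml.utils.data import get_dataset_split_for_phase

    items = list(range(100))

    # Same valid/test semantics as get_dataset_splits, but only the
    # items for the requested phase are returned.
    train_items = get_dataset_split_for_phase(items, "train", valid=0.1, test=0.1)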

ml.utils.data.check_md5(file_path: str | Path, hash_str: str | None, chunk_size: int = 65536) bool[source]

Checks the MD5 of the downloaded file.

Parameters:
  • file_path – Path to the downloaded file.

  • hash_str – Expected MD5 hash of the file; if None, the check is skipped and True is returned.

  • chunk_size – Size of the chunks to read from the file.

Returns:

True if the MD5 matches, False otherwise.

ml.utils.data.check_sha256(file_path: str | Path, hash_str: str | None, chunk_size: int = 65536) bool[source]

Checks the SHA256 of the downloaded file.

Parameters:
  • file_path – Path to the downloaded file.

  • hash_str – Expected SHA256 hash of the file; if None, the check is skipped and True is returned.

  • chunk_size – Size of the chunks to read from the file.

Returns:

True if the SHA256 matches, False otherwise.
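
Both checkers share the same calling convention. A minimal usage sketch (the file path is a placeholder, and it is assumed that hash_str is a hexadecimal digest):

    import hashlib
    from pathlib import Path

    from ml.utils.data import check_md5, check_sha256

    file_path = Path("dataset.tar")  # placeholder path

    # Compute reference digests locally, then verify with the checkers.
    expected_md5 = hashlib.md5(file_path.read_bytes()).hexdigest()
    expected_sha256 = hashlib.sha256(file_path.read_bytes()).hexdigest()
    assert check_md5(file_path, expected_md5)
    assert check_sha256(file_path, expected_sha256)

    # A None hash skips the check entirely.
    assert check_md5(file_path, None)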

class ml.utils.data.Header(files: list[tuple[str, int]], init_offset: int = 0)[source]

Bases: object

The header for a streamable dataset (SDS), recording the name and size of each packed file so that individual files can be located with random access.

files: list[tuple[str, int]]
init_offset: int = 0
encode() bytes[source]
write(fp: IO[bytes]) None[source]
classmethod decode(b: bytes) Header[source]
classmethod read(fp: IO[bytes]) tuple[Header, int][source]
shard(shard_id: int, total_shards: int) Header[source]
offsets(header_size: int) list[int][source]
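
A minimal round-trip sketch (the file entries are placeholders; it is also assumed that the int returned by read is the header size expected by offsets):

    import io

    from ml.utils.data import Header

    # Two files of 1024 and 2048 bytes, stored back-to-back.
    header = Header(files=[("a.bin", 1024), ("b.bin", 2048)])

    # Round-trip through raw bytes.
    restored = Header.decode(header.encode())

    # Round-trip through a file-like object; read also returns an int,
    # assumed here to be the header size.
    buf = io.BytesIO()
    header.write(buf)
    buf.seek(0)
    restored, header_size = Header.read(buf)

    # Byte offset of each file within the archive.
    offsets = restored.offsets(header_size)
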
ml.utils.data.compress_folder_to_sds(input_dir: str | Path, output_path: str | Path, only_extensions: Collection[str] | None = None, exclude_extensions: Collection[str] | None = None) None[source]

Compresses a given folder to a streamable dataset (SDS).

Parameters:
  • input_dir – The directory to compress.

  • output_path – The root directory to write the shards to.

  • only_extensions – If not None, only files with these extensions will be included.

  • exclude_extensions – If not None, files with these extensions will be excluded.
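
A minimal usage sketch (paths are placeholders):

    from ml.utils.data import compress_folder_to_sds

    # Pack the .npy and .json files under data/raw into an SDS dataset,
    # skipping temporary files.
    compress_folder_to_sds(
        input_dir="data/raw",
        output_path="data/sds",
        only_extensions=[".npy", ".json"],
        exclude_extensions=[".tmp"],
    )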

class ml.utils.data.SdsDataPipe(path: str | Path)[source]

Bases: MapDataPipe[tuple[str, int, BinaryIO]]

Defines a base reader for streamable datasets.

This class used to incorporate more functionality, but I've since migrated to smart_open, which handles the various storage backends; the data format is now essentially a TAR file with a more efficient header for random access.

Parameters:
  • shard_id – The index of the current reader shard. If not specified, will default to the current rank.

  • total_shards – The total number of reader shards. If not specified, will default to the world size.

get_header_and_offsets() tuple[ml.utils.data.Header, int][source]
read(start: int, length: int) bytes[source]
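
Each element is a (name, size, file pointer) tuple, per the base class above. A minimal iteration sketch (the path is a placeholder; it is assumed the pipe implements __len__ and that the int field is the file's length in bytes):

    from ml.utils.data import SdsDataPipe

    pipe = SdsDataPipe("data/sds")  # placeholder path

    # MapDataPipe supports random access by index.
    for i in range(len(pipe)):
        name, length, fp = pipe[i]
        data = fp.read(length)
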
ml.utils.data.upload_data_to_s3(file_path: str | Path, prefix: str | None = None, name: str | None = None, bucket: str | None = None) None[source]

Uploads a data file to S3.

Parameters:
  • file_path – The path to the file to upload.

  • prefix – The prefix to use for the uploaded file, if one is provided.

  • name – The name to use for the uploaded file. If not specified, will default to the name of the file.

  • bucket – The bucket to upload to. If not specified, will default to the bucket specified by get_s3_data_bucket.
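
A minimal usage sketch (the bucket, prefix, and file names are placeholders; S3 credentials are assumed to be configured for the underlying client):

    from ml.utils.data import upload_data_to_s3

    # Uploads data/mnist.sds, e.g. to s3://my-bucket/datasets/mnist.sds.
    upload_data_to_s3(
        "data/mnist.sds",
        prefix="datasets",
        name="mnist.sds",
        bucket="my-bucket",
    )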