ml.tasks.datasets.collate

Defines custom collation functions for PyTorch datasets.

ml.tasks.datasets.collate.pil_to_tensor(pic: Image) → Tensor
ml.tasks.datasets.collate.is_named_tuple(obj: Any) → bool
ml.tasks.datasets.collate.pad_sequence(tensors: list[torch.Tensor], *, dim: int = 0, max_length: int | None = None, left_pad: bool = False, left_truncate: bool = False, pad_value: int | float | bool = 0) → list[torch.Tensor]

Pads or truncates a sequence of tensors to the same length.

Parameters:
  • tensors – The tensors to pad or truncate

  • dim – The dimension to pad or truncate

  • max_length – The maximum length along dim; longer tensors are truncated to this length

  • left_pad – If set, pad on the left side, otherwise pad the right side

  • left_truncate – If set, truncate on the left side, otherwise truncate on the right side

  • pad_value – The padding value to use

Returns:

The padded or truncated tensors

Raises:

ValueError – If the tensor dimensions are invalid

ml.tasks.datasets.collate.pad_all(tensors: list[torch.Tensor], *, max_length: int | None = None, left_pad: bool = False, left_truncate: bool = False, pad_value: int | float | bool = 0) → list[torch.Tensor]

Pads all tensors to the same shape.

Parameters:
  • tensors – The tensors to pad

  • max_length – The maximum tensor length

  • left_pad – If set, pad on the left side, otherwise pad the right side

  • left_truncate – If set, truncate on the left side, otherwise truncate on the right side

  • pad_value – The padding value to use

Returns:

The padded tensors
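As a rough illustration of padding every dimension to a common shape, the sketch below right-pads each tensor up to the largest size seen along each dimension. `pad_all_sketch` is a hypothetical helper, not the library function; it assumes all inputs have the same number of dimensions and leaves out `max_length`, `left_pad`, and `left_truncate`.

```python
import torch
import torch.nn.functional as F


def pad_all_sketch(tensors: list[torch.Tensor], *, pad_value: float = 0.0) -> list[torch.Tensor]:
    """Hypothetical helper: right-pads every dimension to the batch maximum."""
    if not tensors:
        return []

    ndim = tensors[0].dim()
    # Target size along each dimension is the largest size in the batch.
    target = [max(t.shape[d] for t in tensors) for d in range(ndim)]

    out = []
    for t in tensors:
        # F.pad expects (left, right) pairs, ordered from the last dimension backwards.
        pad: list[int] = []
        for d in reversed(range(ndim)):
            pad += [0, target[d] - t.shape[d]]
        out.append(F.pad(t, pad, value=pad_value))
    return out


a = torch.ones(2, 3)
b = torch.ones(1, 5)
padded = pad_all_sketch([a, b])
batch = torch.stack(padded)  # Stacking works once all shapes match.
```

Here both tensors end up with shape `(2, 5)`, so they can be stacked into a single batch of shape `(2, 2, 5)`.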

ml.tasks.datasets.collate.collate(items: list[Any], *, mode: Literal['stack', 'concat'] | Callable[[list[torch.Tensor]], Tensor] = 'stack', pad: bool | Callable[[list[torch.Tensor]], list[torch.Tensor]] = False) → Any | None

Defines a general-purpose collating function.

Parameters:
  • items – The list of items to collate

  • mode – Either 'stack', 'concat', or a custom function that takes a list of tensors and returns a single tensor

  • pad – If set to True, pads sequences using the default padding function. Alternatively, pass a function that performs the padding

Returns:

The collated item, or None if the item list was empty

Raises:

NotImplementedError – If the mode is invalid
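The sketch below shows one way a collating function with this contract might behave. `collate_sketch` is a hypothetical name, and the set of supported item types (tensors, numbers, dicts, lists, tuples) is an assumption; this page does not say which containers the real function handles, and the `pad` option is omitted here.

```python
from typing import Any

import torch


def collate_sketch(items: list[Any], *, mode: str = "stack") -> Any:
    """Hypothetical collate function following the documented contract."""
    if len(items) == 0:
        return None  # Documented behavior: empty input yields None.

    first = items[0]
    if isinstance(first, torch.Tensor):
        if mode == "stack":
            return torch.stack(items, dim=0)  # Adds a new batch dimension.
        if mode == "concat":
            return torch.cat(items, dim=0)  # Joins along the existing first dimension.
        raise NotImplementedError(f"Invalid mode: {mode}")
    if isinstance(first, (bool, int, float)):
        return torch.tensor(items)
    if isinstance(first, dict):
        # Assumption: dictionaries are collated key by key.
        return {k: collate_sketch([item[k] for item in items], mode=mode) for k in first}
    if isinstance(first, (list, tuple)):
        # Assumption: sequences are collated position by position.
        return [collate_sketch(list(group), mode=mode) for group in zip(*items)]
    return items  # Assumption: unknown types are passed through unchanged.


samples = [
    {"x": torch.zeros(3), "y": 1},
    {"x": torch.ones(3), "y": 0},
]
batch = collate_sketch(samples)
joined = collate_sketch([torch.zeros(2), torch.zeros(3)], mode="concat")
```

In this sketch, `batch["x"]` has shape `(2, 3)` because stacking adds a batch dimension, while `joined` has shape `(5,)` because concatenation joins along the existing first dimension.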

ml.tasks.datasets.collate.collate_non_null(items: list[Any], *, mode: Literal['stack', 'concat'] | Callable[[list[torch.Tensor]], Tensor] = 'stack', pad: bool | Callable[[list[torch.Tensor]], list[torch.Tensor]] = False) → Any