ml.tasks.datasets.collate
Defines custom collation functions for PyTorch datasets.
- ml.tasks.datasets.collate.pad_sequence(tensors: list[torch.Tensor], *, dim: int = 0, max_length: int | None = None, left_pad: bool = False, left_truncate: bool = False, pad_value: int | float | bool = 0) -> list[torch.Tensor] [source]
Pads or truncates a sequence of tensors to the same length.
- Parameters:
tensors – The tensors to pad or truncate
dim – The dimension to pad or truncate
max_length – The maximum tensor length
left_pad – If set, pad on the left side; otherwise, pad on the right side
left_truncate – If set, truncate on the left side; otherwise, truncate on the right side
pad_value – The padding value to use
- Returns:
The padded or truncated tensors
- Raises:
ValueError – If the tensor dimensions are invalid
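The behavior described above can be sketched as follows. This is a hypothetical re-implementation for illustration only, assuming the function pads or truncates each tensor along `dim` to a common target length (the longest input, capped at `max_length`); the actual implementation in ml.tasks.datasets.collate may differ.

```python
import torch

def pad_sequence_sketch(tensors, *, dim=0, max_length=None, left_pad=False,
                        left_truncate=False, pad_value=0):
    # Hypothetical sketch of pad_sequence; not the library's implementation.
    if not all(t.dim() > dim for t in tensors):
        raise ValueError("Each tensor must have more than `dim` dimensions")
    # Target length: longest tensor along `dim`, capped at max_length.
    target = max(t.shape[dim] for t in tensors)
    if max_length is not None:
        target = min(target, max_length)
    out = []
    for t in tensors:
        length = t.shape[dim]
        if length > target:
            # Truncate: keep the last `target` elements if left_truncate,
            # otherwise keep the first `target` elements.
            start = length - target if left_truncate else 0
            t = t.narrow(dim, start, target)
        elif length < target:
            # Pad with pad_value on the chosen side.
            pad_shape = list(t.shape)
            pad_shape[dim] = target - length
            pad = t.new_full(pad_shape, pad_value)
            t = torch.cat([pad, t] if left_pad else [t, pad], dim=dim)
        out.append(t)
    return out
```

For example, padding a length-3 and a length-5 tensor yields two length-5 tensors, with the shorter one right-padded by default.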
- ml.tasks.datasets.collate.pad_all(tensors: list[torch.Tensor], *, max_length: int | None = None, left_pad: bool = False, left_truncate: bool = False, pad_value: int | float | bool = 0) -> list[torch.Tensor] [source]
Pads all tensors to the same shape.
- Parameters:
tensors – The tensors to pad
max_length – The maximum tensor length
left_pad – If set, pad on the left side; otherwise, pad on the right side
left_truncate – If set, truncate on the left side; otherwise, truncate on the right side
pad_value – The padding value to use
- Returns:
The padded tensors
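Whereas pad_sequence operates on a single dimension, pad_all equalizes every dimension. A minimal sketch, assuming the target shape is the element-wise maximum of the input shapes (capped at `max_length`) and that all inputs share the same rank; the function name and exact semantics here are illustrative, not the library's code:

```python
import torch

def pad_all_sketch(tensors, *, max_length=None, left_pad=False,
                   left_truncate=False, pad_value=0):
    # Hypothetical sketch of pad_all: pad/truncate each dimension so that
    # every tensor ends up with the same shape.
    ndim = tensors[0].dim()
    target = [max(t.shape[d] for t in tensors) for d in range(ndim)]
    if max_length is not None:
        target = [min(s, max_length) for s in target]
    out = []
    for t in tensors:
        for d in range(ndim):
            length = t.shape[d]
            if length > target[d]:
                start = length - target[d] if left_truncate else 0
                t = t.narrow(d, start, target[d])
            elif length < target[d]:
                pad_shape = list(t.shape)
                pad_shape[d] = target[d] - length
                pad = t.new_full(pad_shape, pad_value)
                t = torch.cat([pad, t] if left_pad else [t, pad], dim=d)
        out.append(t)
    return out
```

For example, a (2, 3) tensor and a (3, 2) tensor would both come out shaped (3, 3).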
- ml.tasks.datasets.collate.collate(items: list[Any], *, mode: Literal['stack', 'concat'] | Callable[[list[torch.Tensor]], torch.Tensor] = 'stack', pad: bool | Callable[[list[torch.Tensor]], list[torch.Tensor]] = False) -> Any | None [source]
Defines a general-purpose collating function.
- Parameters:
items – The list of items to collate
mode – Either 'stack', 'concat', or a custom function that takes a list of tensors and returns a single tensor
pad – If True, pad sequences using the default padding function; alternatively, a custom function that performs the padding
- Returns:
The collated item, or None if the item list was empty
- Raises:
NotImplementedError – If the mode is invalid
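The contract above can be sketched for the tensor-list case. This simplified version only handles a flat list of tensors and pads along dimension 0; the real collate function presumably also recurses into nested containers (an assumption, not confirmed by this page):

```python
import torch

def collate_sketch(items, *, mode="stack", pad=False):
    # Hypothetical sketch of collate for lists of tensors only.
    if len(items) == 0:
        return None  # empty item list collates to None
    if callable(pad):
        items = pad(items)
    elif pad:
        # Default padding (assumption): right-pad along dim 0 with zeros
        # so all tensors share the longest length.
        target = max(t.shape[0] for t in items)
        items = [
            torch.cat([t, t.new_full((target - t.shape[0], *t.shape[1:]), 0)])
            if t.shape[0] < target else t
            for t in items
        ]
    if callable(mode):
        return mode(items)
    if mode == "stack":
        return torch.stack(items, dim=0)
    if mode == "concat":
        return torch.cat(items, dim=0)
    raise NotImplementedError(f"Invalid collate mode: {mode}")
```

With mode='stack', three length-4 tensors collate to a (3, 4) batch; with mode='concat' they become a single length-12 tensor; an unknown mode raises NotImplementedError.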