def onnx_compatible_unfold()

in optimum/exporters/onnx/model_patcher.py [0:0]


def onnx_compatible_unfold(input_tensor, dimension, size, step):
    """
    Custom implementation of torch.unfold without using torch.unfold.

    Extracts every window of length `size` along `dimension`, advancing `step` elements between
    consecutive windows. In the result, `dimension` holds the number of windows and a new trailing
    dimension of length `size` holds the contents of each window.

    Args:
        input_tensor (torch.Tensor): The input tensor.
        dimension (int): The dimension to unfold. Negative values count from the last dimension.
        size (int): The size of each slice.
        step (int): The step size between slices. Must be strictly positive.

    Returns:
        torch.Tensor: The unfolded tensor, with one more dimension than `input_tensor`.

    Raises:
        ValueError: If `dimension` is out of range for `input_tensor`, or if `step` is not positive.
        RuntimeError: If `size` is larger than the length of `dimension`, so no slice can be extracted.
    """
    ndim = input_tensor.dim()

    # Check if dimension is within the valid range
    if not (-ndim <= dimension < ndim):
        raise ValueError(
            f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {dimension})"
        )

    # A zero step would divide by zero below, and a negative step would yield negative slice starts.
    if step <= 0:
        raise ValueError(f"Step must be a positive integer, but got {step}")

    # Normalize negative dimension
    dimension = dimension % ndim

    # Compute the number of windows that fit along the unfolded dimension
    input_size = input_tensor.size(dimension)
    num_slices = (input_size - size) // step + 1

    # Move the unfolded dimension to the end so every window is a plain slice of the last axis
    input_tensor = input_tensor.transpose(dimension, -1)

    # Extract slices
    slices = [input_tensor[..., i * step : i * step + size] for i in range(num_slices)]

    # With a positive step, no slice means `size` does not fit in the dimension. Checking the Python
    # list (rather than comparing sizes) fails early with a clear message instead of letting
    # torch.stack reject an empty list.
    if not slices:
        raise RuntimeError(
            f"maximum size for tensor at dimension {dimension} is {input_size} but size is {size}"
        )

    # Stack the windows just before the last axis, then swap the window-count axis back into
    # `dimension`; the axis that was swapped out by the first transpose returns to its original place.
    return torch.stack(slices, dim=-2).transpose(dimension, -2)