def onnx_compatible_repeat_interleave()

in optimum/exporters/onnx/model_patcher.py [0:0]


def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None, output_size=None):
    """
    Custom implementation of torch.repeat_interleave without using torch.repeat_interleave.

    The result is built from `expand` plus boolean-mask indexing instead of calling
    torch.repeat_interleave.

    Args:
        input_tensor (torch.Tensor): The input tensor.
        repeats (int or torch.Tensor): The number of repetitions for each element. A tensor must be
            0-d, or 1-D with either a single element (broadcast over the whole dimension, as
            torch.repeat_interleave does) or `input_tensor.shape[dim]` elements.
        dim (int, optional): The dimension along which to repeat. Defaults to None, in which case
            the input is flattened first and a 1-D tensor is returned.
        output_size (int, optional): Accepted for signature compatibility with
            torch.repeat_interleave; it is not used.

    Returns:
        torch.Tensor: The repeated tensor.
    """
    if isinstance(repeats, int) or (torch.is_tensor(repeats) and repeats.dim() == 0):
        if dim is None:
            # Uniform repeat count over a flattened input: a plain expand is enough.
            return input_tensor.flatten().unsqueeze(1).expand(-1, repeats).flatten()
        repeats = torch.full((input_tensor.shape[dim],), repeats, dtype=torch.long, device=input_tensor.device)

    if dim is None:
        return onnx_compatible_repeat_interleave(input_tensor.flatten(), repeats, 0)

    # Nothing to repeat along an empty dimension; `repeats.max()` below would raise on an
    # empty tensor, whereas torch.repeat_interleave returns an empty result.
    if input_tensor.shape[dim] == 0:
        return input_tensor.clone()

    # torch.repeat_interleave broadcasts a single-element `repeats` over the whole dimension;
    # without this the mask below would have leading size 1 instead of input_tensor.shape[dim].
    if repeats.dim() == 1 and repeats.shape[0] == 1:
        repeats = repeats.expand(input_tensor.shape[dim])

    if dim != 0:
        # Move the target dimension to the front so the masking below only deals with dim 0.
        input_tensor = input_tensor.transpose(0, dim)

    # Create expand mask: every slice is expanded to `max_repeats` copies and the mask keeps
    # only the first `repeats[i]` copies of slice i.
    max_repeats = repeats.max()
    expanded = input_tensor.unsqueeze(1).expand(-1, max_repeats, *input_tensor.shape[1:])
    mask = torch.arange(max_repeats, device=input_tensor.device) < repeats.unsqueeze(1)
    # Boolean indexing with a 2-D mask collapses the first two dims into one, in row-major order.
    result = expanded[mask]

    if dim != 0:
        result = result.transpose(0, dim)

    return result