in optimum/exporters/onnx/model_patcher.py
import torch


def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None, output_size=None):
    """
    Custom implementation of `torch.repeat_interleave`, which is not always exportable to ONNX.

    Args:
        input_tensor (torch.Tensor): The input tensor.
        repeats (int or torch.Tensor): The number of repetitions for each element, either a single
            value applied to every element or a 1-D tensor with one entry per element along `dim`.
        dim (int, optional): The dimension along which to repeat. Defaults to None, in which case
            the input is flattened before repeating, matching `torch.repeat_interleave`.
        output_size (int, optional): Unused; accepted for signature compatibility with
            `torch.repeat_interleave`.

    Returns:
        torch.Tensor: The repeated tensor.
    """
    if isinstance(repeats, int) or (torch.is_tensor(repeats) and repeats.dim() == 0):
        # Scalar repeats with dim=None: flatten, then repeat every element `repeats`
        # times by expanding along a new axis.
        if dim is None:
            return input_tensor.flatten().unsqueeze(1).expand(-1, repeats).flatten()
        # Otherwise broadcast the scalar into a per-element repeats tensor and fall
        # through to the general case below.
        repeats = torch.full((input_tensor.shape[dim],), repeats, dtype=torch.long, device=input_tensor.device)
    if dim is None:
        # Tensor repeats with no dim: flatten, then repeat along dimension 0.
        return onnx_compatible_repeat_interleave(input_tensor.flatten(), repeats, 0)
    if dim != 0:
        # Move the target dimension to the front so the masking below is uniform.
        input_tensor = input_tensor.transpose(0, dim)
    # Expand every element to the maximum repeat count, then keep only the first
    # `repeats[i]` copies of element i with a boolean mask. Boolean indexing returns
    # the kept rows in order, which yields exactly the repeat_interleave layout.
    max_repeats = repeats.max()
    expanded = input_tensor.unsqueeze(1).expand(-1, max_repeats, *input_tensor.shape[1:])
    mask = torch.arange(max_repeats, device=input_tensor.device) < repeats.unsqueeze(1)
    result = expanded[mask]
    if dim != 0:
        # Move the repeated dimension back to its original position.
        result = result.transpose(0, dim)
    return result
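
A quick sanity check (a minimal sketch, not part of the original file; it assumes the function above is in scope and simply compares its output against eager-mode `torch.repeat_interleave`):

import torch

x = torch.arange(6).reshape(2, 3)

# Scalar repeats along an explicit dimension.
assert torch.equal(
    onnx_compatible_repeat_interleave(x, 2, dim=1),
    torch.repeat_interleave(x, 2, dim=1),
)

# Per-element repeats along dim 0.
r = torch.tensor([1, 3])
assert torch.equal(
    onnx_compatible_repeat_interleave(x, r, dim=0),
    torch.repeat_interleave(x, r, dim=0),
)

# dim=None flattens the input before repeating.
assert torch.equal(
    onnx_compatible_repeat_interleave(x, 2),
    torch.repeat_interleave(x, 2),
)

The expand-and-mask approach trades memory (every element is expanded to `repeats.max()` copies before masking) for a graph built only from ops with stable ONNX support.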