def maybe_pad_interleaved()

in optimum/neuron/models/inference/backend/modules/attention/gqa.py
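
Interleave-pads a tensor whose pad_dim axis packs source_heads heads: zero heads are inserted after every group of source_group_size original heads until the axis holds target_heads heads. A None input is returned unchanged, and FP8 tensors are temporarily cast to bfloat16 because torch.cat and torch.zeros do not support the f8e4m3/f8e5m2 dtypes.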


import torch


def maybe_pad_interleaved(tensor, pad_dim: int, source_heads: int, target_heads: int, source_group_size: int):
    if tensor is None:
        return tensor

    # Why do we convert an FP8 tensor to bfloat16?
    # Torch does not support torch.cat or torch.zeros (for large dimensions) for f8e4m3/f8e5m2,
    # so we cast to bfloat16, perform the padding, and then recast back to f8e4m3/f8e5m2.
    recast_dtype = None
    if tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        recast_dtype = tensor.dtype
        tensor = tensor.to(torch.bfloat16)

    # Unfold the flattened heads axis into explicit (source_heads, head_dim) dimensions.
    shape = (
        tensor.shape[:pad_dim] + (source_heads, tensor.shape[pad_dim] // source_heads) + tensor.shape[pad_dim + 1 :]
    )
    tensor = tensor.view(shape)

    # Split the heads into groups of source_group_size; each group gets its own zero padding.
    splits = torch.split(tensor, source_group_size, dim=pad_dim)

    # Each group receives (target_heads - source_heads) // num_groups zero heads,
    # where num_groups = source_heads // source_group_size.
    pad_size = list(splits[0].size())
    pad_size[pad_dim] = (target_heads - source_heads) // (source_heads // source_group_size)
    # The same zero block can be aliased for every group, since torch.cat only reads it.
    pads = [torch.zeros(pad_size, dtype=tensor.dtype, device=tensor.device)] * len(splits)

    # Interleave each group with its zero padding: [group_0, pad, group_1, pad, ...].
    interleaved = [t for pair in zip(splits, pads) for t in pair]
    tensor = torch.cat(interleaved, dim=pad_dim)

    # Fold the (target_heads, head_dim) pair back into a single flattened heads axis.
    shape = tensor.shape[:pad_dim] + (tensor.shape[pad_dim] * tensor.shape[pad_dim + 1],) + tensor.shape[pad_dim + 2 :]

    if recast_dtype is not None:
        tensor = tensor.to(recast_dtype)

    return tensor.view(shape)
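
A minimal usage sketch with made-up dimensions (4 source heads in groups of 2, padded to 8 heads): the function inserts (8 - 4) // (4 // 2) = 2 zero heads after each group of 2 original heads.

head_dim, hidden = 8, 16
w = torch.randn(4 * head_dim, hidden)  # 4 source heads fused along dim 0

padded = maybe_pad_interleaved(w, pad_dim=0, source_heads=4, target_heads=8, source_group_size=2)
print(padded.shape)  # torch.Size([64, 16]) -> 8 heads * head_dim

# Original head groups land at positions 0-1 and 4-5; positions 2-3 and 6-7 are zero padding.
assert torch.all(padded.view(8, head_dim, hidden)[[2, 3, 6, 7]] == 0)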