in optimum/quanto/tensor/grouped.py [0:0]
def ungroup(grouped: torch.Tensor, axis: int, orig_shape: torch.Size) -> torch.Tensor:
    """Restore a group-wise reshaped tensor to its original shape.

    Args:
        grouped: the grouped tensor. For ``axis=0`` it is simply reshaped, so its
            elements must already be in the row-major order of ``orig_shape``.
            For ``axis=-1`` its first dimension is the group size and its
            elements are read as a ``(group_size, axis_dim, axis_groups)`` array,
            where ``axis_dim = orig_shape[-1]``.
        axis: the axis the tensor was grouped along. Only ``0`` and ``-1`` are
            supported.
        orig_shape: the shape to restore.

    Returns:
        A tensor of shape ``orig_shape``. ``grouped`` itself is returned
        unchanged when it already has that shape.

    Raises:
        ValueError: if ``axis`` is neither ``0`` nor ``-1`` and ``grouped`` does
            not already have ``orig_shape``.
    """
    if grouped.shape == orig_shape:
        # Already in the original layout: nothing to undo.
        return grouped
    if axis == 0:
        # No transposition required, just reshape.
        return grouped.reshape(orig_shape)
    if axis != -1:
        # Any other axis would derive group_size from the wrong dimension and
        # could reshape "successfully" into scrambled data, so reject it.
        raise ValueError(f"Axis must be 0 or -1 for group-wise quantization, got {axis}")
    # In the axis=-1 layout, the group size is the leading dimension.
    group_size = grouped.shape[0]
    axis_dim = orig_shape[axis]
    axis_groups = grouped.numel() // axis_dim // group_size
    ungrouped = grouped.reshape(group_size, axis_dim, axis_groups)
    # Permute to (axis_groups, group_size, axis_dim) so that the row-major
    # flattening lines up with orig_shape for the final reshape.
    ungrouped = ungrouped.permute(2, 0, 1)
    return ungrouped.reshape(orig_shape)