def ungroup()

in optimum/quanto/tensor/grouped.py [0:0]


def ungroup(grouped: torch.Tensor, axis: int, orig_shape: torch.Size):
    """Restore a group-wise reshaped tensor to its original shape.

    This undoes a grouping transform (presumably the matching ``group`` helper
    for group-wise quantization — not visible here, confirm) that rearranged a
    tensor of shape ``orig_shape``.

    Args:
        grouped: The grouped tensor. For ``axis == -1`` it is read as
            ``(group_size, axis_dim * axis_groups)``: the group size is taken
            from its first dimension.
        axis: The axis that was used for grouping. Only ``0`` and ``-1`` are
            supported.
        orig_shape: The shape of the tensor before grouping.

    Returns:
        A tensor of shape ``orig_shape``. If ``grouped`` already has that
        shape, it is returned as-is.

    Raises:
        ValueError: If ``axis`` is neither ``0`` nor ``-1``.
    """
    if grouped.shape == orig_shape:
        # Nothing to undo: the tensor is already in its original layout.
        return grouped
    if axis not in (0, -1):
        # Any other axis would make the group-size inference below read the
        # wrong dimension and produce a reshape error or a scrambled tensor.
        raise ValueError(f"ungroup only supports axis 0 or -1, got {axis}")
    if axis == 0:
        # No transposition required, just reshape.
        return grouped.reshape(orig_shape)
    # axis == -1: the grouped layout is (group_size, axis_dim * axis_groups).
    group_size = grouped.shape[0]
    axis_dim = orig_shape[axis]
    axis_groups = grouped.numel() // axis_dim // group_size
    # Split the second dimension back into (axis_dim, axis_groups).
    ungrouped = grouped.reshape(group_size, axis_dim, axis_groups)
    # Permute to (axis_groups, group_size, axis_dim), whose row-major order
    # is then reshaped straight back to the original shape.
    ungrouped = ungrouped.permute(2, 0, 1)
    return ungrouped.reshape(orig_shape)