# group() — from optimum/quanto/tensor/grouped.py

def group(base: torch.Tensor, axis: int, group_size: int) -> torch.Tensor:
    """Reshape a tensor so that each row (axis 0) or column (axis -1) group occupies one line.

    Group-wise quantization evaluates one scale per group of ``group_size`` items
    instead of one scale per feature along ``axis``. This helper rearranges ``base``
    so a per-axis quantizer applied to the result is effectively group-wise.

    Args:
        base: the tensor to regroup.
        axis: the quantization axis; only ``0`` (per-row) and ``-1`` (per-column)
            are supported.
        group_size: number of items per group; must be a positive divisor of
            ``base.numel() // base.shape[axis]``.

    Returns:
        For ``axis == 0``: a view of shape ``(axis_dim * axis_groups, group_size)``.
        For ``axis == -1``: a tensor of shape ``(group_size, axis_dim * axis_groups)``.

    Raises:
        ValueError: if ``axis`` is not 0 or -1, or if ``group_size`` is not a
            positive divisor of the number of items per feature.
    """
    if axis not in (0, -1):
        raise ValueError("Axis must be 0 or -1 for group-wise quantization")
    # Guard explicitly: group_size == 0 would otherwise raise ZeroDivisionError in the
    # modulo below, and a negative group_size (e.g. -2 with axis_numel 4) would pass
    # both checks and fail later inside reshape with a confusing error.
    if group_size <= 0:
        raise ValueError(f"Group size must be a positive integer, got {group_size}")
    # In standard per-axis quantization, we have one scale per axis dim
    axis_dim = base.shape[axis]
    # This scale is evaluated over axis_numel items for each feature along axis
    axis_numel = base.numel() // axis_dim
    if group_size > axis_numel or axis_numel % group_size != 0:
        raise ValueError(f"Group size ({group_size}) must be a divisor of ({axis_numel})")
    # Group-wise quantization further splits axis_numel into multiple groups per axis
    axis_groups = axis_numel // group_size
    if axis == 0:
        # Easy-peasy: we simply need to reshape to (axis_dim * axis_groups, group_size)
        return base.reshape([-1, group_size])
    # More difficult: reshape to (group_size, axis_dim * axis_groups)
    # First, split by groups, preserving the axis dimension
    grouped = base.reshape((axis_groups, group_size, axis_dim))
    # Permute to (group_size, axis_dim, axis_groups)
    grouped = grouped.permute(1, 2, 0)
    return grouped.reshape(group_size, axis_dim * axis_groups)