in optimum/quanto/tensor/grouped.py [0:0]
def group(base: torch.Tensor, axis: int, group_size: int):
    """Rearrange ``base`` into a 2-D layout of quantization groups.

    In per-axis quantization there is one scale per feature along ``axis``;
    group-wise quantization subdivides each feature's elements into chunks of
    ``group_size`` so that every chunk gets its own scale.

    Args:
        base: the tensor to regroup.
        axis: quantization axis; only ``0`` and ``-1`` are supported.
        group_size: number of elements per group; must evenly divide the
            number of elements covered by each per-axis scale.

    Returns:
        For ``axis == 0`` a tensor of shape ``(features * groups, group_size)``;
        for ``axis == -1`` a tensor of shape ``(group_size, features * groups)``.

    Raises:
        ValueError: if ``axis`` is unsupported or ``group_size`` does not
            divide the per-feature element count.
    """
    if axis not in (0, -1):
        raise ValueError("Axis must be 0 or -1 for group-wise quantization")
    # Plain per-axis quantization would use one scale per feature along `axis` ...
    n_features = base.shape[axis]
    # ... with each scale covering this many elements.
    items_per_feature = base.numel() // n_features
    if group_size > items_per_feature or items_per_feature % group_size != 0:
        raise ValueError(f"Group size ({group_size}) must be a divisor of ({items_per_feature})")
    # Each feature is split into this many groups.
    n_groups = items_per_feature // group_size
    if axis == 0:
        # Row-major order already stores each feature's items contiguously,
        # so a flat reshape directly yields (n_features * n_groups, group_size).
        return base.reshape([-1, group_size])
    # axis == -1: first split into groups while keeping the feature dimension last,
    regrouped = base.reshape((n_groups, group_size, n_features))
    # then bring group_size to the front: (group_size, n_features, n_groups),
    regrouped = regrouped.permute(1, 2, 0)
    # and collapse the trailing two dims into (group_size, n_features * n_groups).
    return regrouped.reshape(group_size, n_features * n_groups)