optimum/quanto/tensor/grouped.py (33 lines of code) (raw):

import math
from typing import List, Tuple

import torch


__all__ = ["group", "ungroup", "grouped_shape"]


def _validate_axis(axis: int) -> None:
    """Raise ``ValueError`` unless ``axis`` is one of the supported group-wise axes (0 or -1)."""
    if axis not in (0, -1):
        raise ValueError("Axis must be 0 or -1 for group-wise quantization")


def grouped_shape(shape: List, axis: int, group_size: int) -> Tuple[int, int]:
    """Return the 2D shape that :func:`group` produces for a tensor of ``shape``.

    Args:
        shape: the original (ungrouped) shape.
        axis: the quantization axis, 0 or -1.
        group_size: the number of elements per group.

    Returns:
        ``(n_groups, group_size)`` when ``axis == 0`` and ``(group_size, n_groups)``
        when ``axis == -1``, with ``n_groups = prod(shape) // group_size``.

    Raises:
        ValueError: if ``axis`` is not 0 or -1.
    """
    _validate_axis(axis)
    # NOTE: divisibility of prod(shape) by group_size is not checked here;
    # group() validates it against the actual tensor.
    n_groups = math.prod(shape) // group_size
    if axis == 0:
        return (n_groups, group_size)
    return (group_size, n_groups)


def group(base: torch.Tensor, axis: int, group_size: int) -> torch.Tensor:
    """Reshape ``base`` into a 2D tensor of groups of ``group_size`` elements.

    In standard per-axis quantization there is one scale per index along ``axis``,
    evaluated over the ``base.numel() // base.shape[axis]`` elements sharing that
    index. Group-wise quantization further splits those elements into groups of
    ``group_size`` elements.

    Args:
        base: the tensor to group.
        axis: the quantization axis, 0 or -1.
        group_size: a positive divisor of ``base.numel() // base.shape[axis]``.

    Returns:
        For ``axis == 0``: a tensor of shape ``(axis_dim * axis_groups, group_size)``
        where each row is one group.
        For ``axis == -1``: a tensor of shape ``(group_size, axis_dim * axis_groups)``
        where each column is one group.

    Raises:
        ValueError: if ``axis`` is not 0 or -1, or if ``group_size`` is not a
            positive divisor of the number of elements per ``axis`` index.
    """
    _validate_axis(axis)
    # In standard per-axis quantization, we have one scale per axis dim
    axis_dim = base.shape[axis]
    # This scale is evaluated over axis_numel items for each feature along axis
    axis_numel = base.numel() // axis_dim
    # group_size <= 0 is checked first so the modulo below never divides by zero.
    if group_size <= 0 or group_size > axis_numel or axis_numel % group_size != 0:
        raise ValueError(f"Group size ({group_size}) must be a divisor of ({axis_numel})")
    # Group-wise quantization further splits axis_numel into multiple groups per axis
    axis_groups = axis_numel // group_size
    if axis == 0:
        # Easy-peasy: we simply need to reshape to (axis_dim * axis_groups, group_size)
        return base.reshape([-1, group_size])
    # More difficult: reshape to (group_size, axis_dim * axis_groups)
    # First, split by groups, preserving the axis dimension
    grouped = base.reshape((axis_groups, group_size, axis_dim))
    # Permute to (group_size, axis_dim, axis_groups)
    grouped = grouped.permute(1, 2, 0)
    return grouped.reshape(group_size, axis_dim * axis_groups)


def ungroup(grouped: torch.Tensor, axis: int, orig_shape: torch.Size) -> torch.Tensor:
    """Invert :func:`group`, restoring a tensor of shape ``orig_shape``.

    Args:
        grouped: a tensor laid out as produced by :func:`group`.
        axis: the axis that was passed to :func:`group`, 0 or -1.
        orig_shape: the shape of the tensor before grouping.

    Returns:
        A tensor of shape ``orig_shape``. If ``grouped`` already has that shape
        it is returned unchanged.

    Raises:
        ValueError: if the shapes differ and ``axis`` is not 0 or -1.
    """
    if grouped.shape == orig_shape:
        # Shapes already match: return the input object unchanged.
        return grouped
    _validate_axis(axis)
    if axis == 0:
        # No transposition required, just reshape
        return grouped.reshape(orig_shape)
    # axis == -1: grouped is (group_size, axis_dim * axis_groups)
    group_size = grouped.shape[0]
    axis_dim = orig_shape[axis]
    axis_groups = grouped.numel() // axis_dim // group_size
    ungrouped = grouped.reshape(group_size, axis_dim, axis_groups)
    # Permute to (axis_groups, group_size, axis_dim)
    ungrouped = ungrouped.permute(2, 0, 1)
    return ungrouped.reshape(orig_shape)