in optimum/quanto/tensor/weights/tinygemm/qbits.py [0:0]
def __init__(self, qtype, axis, group_size, size, stride, data, scale_shift, requires_grad=False):
    assert axis == 0
    if not isinstance(data, TinyGemmPackedTensor):
        assert type(data) is torch.Tensor
        assert isinstance(scale_shift, (tuple, list))
        # Format data, scale and shift for tinygemm
        ungrouped = ungroup(data, axis=0, orig_shape=size)
        self._data = TinyGemmPackedTensor.pack(ungrouped)
        out_features, in_features = size
        scale, shift = scale_shift
        scale = scale.reshape(out_features, in_features // group_size, 1)
        shift = shift.reshape(out_features, in_features // group_size, 1)
        if not shift.dtype.is_floating_point:
            # Integer shift must be scaled
            shift = scale * shift
        # The tinygemm kernel actually uses the mid-point of the quantization range as shift
        min_range = -shift
        half_qrange = 2 ** (qtype.bits - 1) * scale
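        # e.g. for 4-bit weights (qtype.bits == 4), half_qrange is 8 * scale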
        # This operation is lossy for bfloat16, and the actual value of shift will be lost
        shift = min_range + half_qrange
        # Scale and shift are actually stored in the same tensor
        self._scale_shift = torch.cat([scale, shift], 2).transpose(0, 1).contiguous()
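        # -> scale_shift ends up with shape (in_features // group_size, out_features, 2)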
    else:
        self._data = data
        self._scale_shift = scale_shift
    self._qtype = qtype
    self._axis = axis
    self._group_size = group_size
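
For reference, a minimal standalone sketch (not part of qbits.py; the example values below are assumed) showing why re-centering the shift on the mid-point of the quantization range preserves the original affine dequantization x = scale * (q - shift):

# Standalone sketch with assumed example values, illustrating the shift re-centering above.
import torch

bits = 4                                          # qtype.bits for int4 weights
scale = torch.tensor([0.05])
shift = torch.tensor([7.0])                       # integer zero-point from the quantizer
q = torch.arange(2**bits, dtype=torch.float32)    # every representable 4-bit code

x_ref = scale * (q - shift)                       # original affine dequantization

scaled_shift = scale * shift                      # "Integer shift must be scaled"
mid_shift = -scaled_shift + 2 ** (bits - 1) * scale   # min_range + half_qrange
x_mid = scale * (q - 2 ** (bits - 1)) + mid_shift     # tinygemm-style dequantization

assert torch.allclose(x_ref, x_mid)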