in optimum/quanto/tensor/weights/awq/qbits.py
def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
    # AWQ kernels only support quantization along the first axis (out_features)
    assert axis == 0
    if not isinstance(data, AWQPackedTensor):
        assert type(data) is torch.Tensor
        # Format data, scale and shift for the optimized CUDA gemm kernel
        ungrouped = ungroup(data, axis=0, orig_shape=size)
        data = AWQPackedTensor.pack(ungrouped, packing=AWQPacking.V2)
        out_features, in_features = size
        # The kernel expects per-group scales and shifts in
        # (in_features // group_size, out_features) layout
        scale = scale.reshape(out_features, in_features // group_size).t().contiguous()
        shift = shift.reshape(out_features, in_features // group_size).t()
        if not shift.dtype.is_floating_point:
            # Integer shift must be scaled to match the dequantized weights
            shift = scale * shift
        # Shift must be negated, as the kernel applies it with an addition
        shift = -shift.contiguous()
    super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)
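
The scale/shift layout change is the subtle part of this constructor. Below is a minimal, standalone sketch (not from the repository; the dimensions and variable names are made up for illustration) of the same reshape/transpose/negate sequence and the shapes it produces:

import torch

out_features, in_features, group_size = 128, 256, 64  # illustrative sizes
n_groups = in_features // group_size                   # 4 groups per output row

# Per-group scales from group-wise quantization: one value per
# (output channel, input group) pair.
scale = torch.rand(out_features * n_groups)

# The reshape + transpose: (out_features, n_groups) -> (n_groups, out_features),
# so the gemm kernel reads all output-channel scales of one input group contiguously.
kernel_scale = scale.reshape(out_features, n_groups).t().contiguous()
assert kernel_scale.shape == (n_groups, out_features)

# Integer zero-points get the same layout, are expressed in dequantized units
# (scale * shift), and negated so dequantization becomes q * scale + (-shift * scale).
shift = torch.randint(0, 15, (out_features * n_groups,))
kernel_shift = -(kernel_scale * shift.reshape(out_features, n_groups).t()).contiguous()
assert kernel_shift.shape == (n_groups, out_features)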