in optimum/quanto/tensor/weights/marlin/int4/qbits.py [0:0]
def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
    assert axis == 0
    out_features, in_features = size
    if not isinstance(data, MarlinInt4PackedTensor):
        assert type(data) is torch.Tensor
        # Format data, scale and shift for the optimized CUDA gemm
        ungrouped = ungroup(data, axis=0, orig_shape=size)
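        # Shape note (an inference from the grouped layout, not stated in this
        # file): grouping along axis 0 stores the codes as
        # (out_features * in_features // group_size, group_size), and ungroup()
        # restores the original (out_features, in_features) matrix that the
        # packing step below expects.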
        data = MarlinInt4PackedTensor.pack(ungrouped)
        scale = scale.reshape(out_features, in_features // group_size).t().contiguous()
        shift = shift.reshape(out_features, in_features // group_size).t()
        if not shift.dtype.is_floating_point:
            # Integer shift must be scaled
            shift = scale * shift
        # Shift must be negated
        shift = -shift.contiguous()
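        # Why the negation (algebra implied by the affine scheme, not spelled
        # out here): dequantization is w = scale * (q - z) = scale * q - scale * z,
        # so a kernel computing scale * q + shift needs the pre-scaled shift
        # stored with its sign flipped: shift = -(scale * z).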
        # Finally, apply scale and shift permutations
        scale = marlin_permute(scale)
        shift = marlin_permute(shift)
    super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)
    # Zero-initialized scratch buffer required by the Marlin GEMM kernel
    self._workspace = torch.zeros(out_features // 128 * 16, dtype=torch.int, device=data.device)
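
Below is a hypothetical construction sketch, not taken from the repository. It assumes the enclosing class is MarlinInt4WeightQBitsTensor, that qint4 is importable from optimum.quanto, that a CUDA device is available, that the codes are uint8 int4 values, and the grouped layout inferred above (one row of codes and one scale/shift entry per group). The dimension choices are illustrative only.

import torch
from optimum.quanto import qint4

out_features, in_features, group_size = 256, 256, 128
n_groups = out_features * in_features // group_size

# Grouped uint8 codes holding int4 values in [0, 15], one (scale, shift) per group.
data = torch.randint(0, 16, (n_groups, group_size), dtype=torch.uint8, device="cuda")
scale = torch.rand(n_groups, 1, dtype=torch.float16, device="cuda") * 0.1
shift = torch.randint(0, 16, (n_groups, 1), dtype=torch.uint8, device="cuda")  # integer zero-point

marlin_qbits = MarlinInt4WeightQBitsTensor(
    qtype=qint4,
    axis=0,
    group_size=group_size,
    size=(out_features, in_features),
    stride=(in_features, 1),
    data=data,     # plain torch.Tensor: __init__ packs and permutes it
    scale=scale,
    shift=shift,
)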