in optimum/quanto/tensor/weights/awq/qbits.py:
import torch


class AWQBitsLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, other, bias):
        # Keep the activation and the packed weights around for the backward pass.
        ctx.save_for_backward(input, other)
        # The kernel expects a plain fp16 activation tensor: dequantize the
        # input if it arrives as a quantized tensor subclass.
        if type(input) is not torch.Tensor:
            input = input.dequantize()
        out_features, in_features = other.shape
        # Fold all leading dimensions into the row dimension of a 2-D GEMM.
        rows = input.numel() // in_features
        # Custom AWQ kernel: fp16 activations times int4 group-quantized weights.
        output = torch.ops.quanto.gemm_f16i4_awq(
            input,
            other._data._data,
            other._scale,
            other._shift,
            rows=rows,
            out_cols=out_features,
            in_cols=in_features,
            bits=4,
            group_size=other._group_size,
        )
        if bias is not None:
            output = output + bias
        return output
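
A minimal call-site sketch, assuming the forward above belongs to
AWQBitsLinearFunction and that `weight` is an already-quantized AWQ 4-bit
weight tensor exposing the _data, _scale, _shift and _group_size attributes
used above; `weight` and `bias` are assumed to come from an existing
quantized linear layer, not constructed here:

    import torch

    out_features, in_features = weight.shape
    # fp16 activations with arbitrary leading dimensions, e.g. (batch, seq, in).
    input = torch.randn(2, 8, in_features, dtype=torch.float16, device="cuda")
    # Custom autograd Functions are invoked through .apply(), which routes
    # to the forward() defined above and records the backward graph.
    output = AWQBitsLinearFunction.apply(input, weight, bias)
    # rows = input.numel() // in_features = 16 here, so the kernel performs
    # a (16, in_features) x (in_features, out_features) fp16/int4 GEMM.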