in optimum/quanto/tensor/weights/marlin/int4/qbits.py [0:0]
def forward(ctx, input, other, bias):
    """Forward pass of the Marlin fp16 x int4 quantized linear op.

    Computes ``input @ other^T (+ bias)`` via the fused
    ``quanto::gemm_f16i4_marlin`` custom kernel, where ``other`` is an
    int4-quantized weight tensor carrying its packed data, scale, shift
    and Marlin workspace buffers.

    Args:
        ctx: autograd context; ``input`` and ``other`` are saved for backward.
        input: activation tensor. May be a quantized tensor subclass, in
            which case it is dequantized to a plain ``torch.Tensor`` first.
        other: quantized weight (project type) exposing ``_data._data``,
            ``_scale``, ``_shift`` and ``_workspace``.
        bias: optional bias tensor added to the kernel output, or ``None``.

    Returns:
        The output tensor of the fused GEMM, with ``bias`` added when given.
    """
    ctx.save_for_backward(input, other)
    # NOTE: `type(...) is not torch.Tensor` (not isinstance) is deliberate:
    # quantized activations subclass torch.Tensor, so isinstance would skip
    # the dequantization they require before the fp16 kernel call.
    if type(input) is not torch.Tensor:
        input = input.dequantize()
    # The kernel consumes the raw packed int4 payload plus its quantization
    # metadata and the Marlin scratch workspace directly.
    output = torch.ops.quanto.gemm_f16i4_marlin(
        input,
        other._data._data,
        other._scale,
        other._shift,
        other._workspace,
    )
    if bias is not None:
        output = output + bias
    return output