in optimum/quanto/tensor/activations/qbytes_ops.py [0:0]
def bmm(op, input, other):
    """Handle a batched matrix multiply involving an ``ActivationQBytesTensor``.

    Dispatch order:
        1. ``input`` is not an ``ActivationQBytesTensor``: dequantize ``other``
           and call ``op`` on the result (presumably ``other`` is the quantized
           operand that routed the call here).
        2. ``other`` is not a ``QTensor``, or ``input.axis`` is not ``None``:
           dequantize ``input`` and call ``op``.
        3. Either operand's ``qtype`` is not ``qint8``, or ``cannot_mm(other)``
           is true: delegate to ``qfallback``.
        4. Otherwise (fast path): multiply the raw integer data in float32 and
           rescale by the product of both scales.

    Args:
        op: The underlying batched-matmul operator, called as ``op(a, b)``.
        input: Left operand.
        other: Right operand.

    Returns:
        The result of ``op`` or ``qfallback`` on the slow paths. On the fast
        path, a tensor cast to ``input._scale.dtype``.
    """
    if not isinstance(input, ActivationQBytesTensor):
        return op(input, other.dequantize())
    if not isinstance(other, QTensor) or input.axis is not None:
        return op(input.dequantize(), other)
    if input.qtype != qint8 or other.qtype != qint8 or cannot_mm(other):
        return qfallback(op, input, other)
    # Cast the integer data to float32 and do the operation in full precision.
    out_data = op(input._data.to(torch.float32), other._data.to(torch.float32))
    # Upcast each scale *before* multiplying. In a half-precision scale dtype,
    # the product of two small scales can underflow or lose precision, and
    # casting an already-rounded product to float32 cannot recover it.
    out_scale = input._scale.to(torch.float32) * other._scale.to(torch.float32)
    # Rescale in float32, then return in the input scale's dtype.
    return (out_data * out_scale).to(input._scale.dtype)