def bmm()

in optimum/quanto/tensor/activations/qbytes_ops.py [0:0]


def bmm(op, input, other):
    """Apply ``op`` to ``input`` and ``other``, handling quantized operands.

    The handler tries, in order:

    1. ``input`` is not an ``ActivationQBytesTensor``: dequantize ``other``
       and call ``op`` on the result.
    2. ``other`` is not a ``QTensor``, or ``input.axis`` is set: dequantize
       ``input`` and call ``op`` on the result.
    3. Either operand is not ``qint8``, or ``cannot_mm(other)`` is true:
       defer to ``qfallback``.
    4. Otherwise: run ``op`` on the raw ``_data`` of both operands cast to
       float32, then rescale by the product of both ``_scale`` tensors.

    Args:
        op: Callable invoked as ``op(lhs, rhs)``.
            NOTE(review): presumably the aten ``bmm`` overload this handler
            is registered for -- confirm at the registration site.
        input: Left operand.
        other: Right operand. Branch 1 calls ``other.dequantize()``
            unconditionally, so it assumes ``other`` is quantized whenever
            ``input`` is not -- TODO confirm the dispatcher guarantees this.

    Returns:
        The result of ``op``. On the float32 path (branch 4) it is cast to
        ``input._scale.dtype``.
    """
    if not isinstance(input, ActivationQBytesTensor):
        return op(input, other.dequantize())

    # Only a QTensor `other` paired with an `input` whose axis is None is
    # eligible for the quantized paths below.
    # NOTE(review): `axis is None` presumably means per-tensor scale -- confirm.
    if not (isinstance(other, QTensor) and input.axis is None):
        return op(input.dequantize(), other)

    has_non_int8_operand = input.qtype != qint8 or other.qtype != qint8
    if has_non_int8_operand or cannot_mm(other):
        return qfallback(op, input, other)

    # Run the operation on the raw data of both operands, cast to float32.
    lhs = input._data.to(torch.float32)
    rhs = other._data.to(torch.float32)
    raw_product = op(lhs, rhs)

    # Combine both scales in float32, apply them, then cast the result to
    # the dtype of the input's scale.
    combined_scale = (input._scale * other._scale).to(torch.float32)
    scaled = raw_product * combined_scale
    return scaled.to(input._scale.dtype)