in optimum/quanto/tensor/activations/qbytes_ops.py [0:0]
def cat(op, inputs, dim=0):
    """Concatenate tensors, keeping the result quantized when possible.

    The raw ``_data`` of ``ActivationQBytesTensor`` inputs can be concatenated
    directly only when every input shares the same ``qtype`` and an identical
    ``_scale`` (compared with ``torch.equal``). In that case the shared scale
    is still valid for every slice of the concatenated data. Any number of
    such inputs is accepted, not just two.

    Every other case is delegated to ``qfallback``:
    - an empty sequence,
    - any input that is not an ``ActivationQBytesTensor``,
    - mismatched qtypes or scales,
    - float8 qtypes, for which cat is not supported.

    Args:
        op: The concatenation operator being dispatched; it is called as
            ``op(list_of_tensors, dim)``.
        inputs: Sequence of tensors to concatenate.
        dim: Dimension along which to concatenate.

    Returns:
        A new ``ActivationQBytesTensor`` wrapping the concatenated data with
        the shared qtype and scale, or whatever ``qfallback`` returns.
    """
    if len(inputs) == 0:
        return qfallback(op, inputs, dim)

    first = inputs[0]
    others = inputs[1:]

    if not all(isinstance(t, ActivationQBytesTensor) for t in inputs):
        return qfallback(op, inputs, dim)

    # Compare qtypes before scales: it is cheap, and scale tensors are only
    # comparable once the inputs share a quantization type.
    if not all(t.qtype == first.qtype for t in others):
        return qfallback(op, inputs, dim)

    # Cat is not supported for float8. All qtypes are equal at this point,
    # so checking the first input is enough.
    if first.qtype.is_floating_point:
        return qfallback(op, inputs, dim)

    # A single scale must dequantize every slice of the result, so all
    # scales have to be identical.
    if not all(torch.equal(t._scale, first._scale) for t in others):
        return qfallback(op, inputs, dim)

    out_data = op([t._data for t in inputs], dim)
    return ActivationQBytesTensor(first.qtype, out_data.size(), out_data.stride(), out_data, first._scale)