def cat()

in optimum/quanto/tensor/activations/qbytes_ops.py [0:0]


def cat(op, inputs, dim=0):
    """Concatenate inputs, keeping the result quantized when possible.

    The quantized fast path is only taken when exactly two inputs are given,
    both are ``ActivationQBytesTensor`` instances, their scales compare equal
    with ``torch.equal``, their qtypes match, and the qtype is not a
    floating-point one. Every other case is delegated to ``qfallback``.

    Args:
        op: The concatenation operator being dispatched; it is called on the
            raw ``_data`` tensors in the fast path and forwarded to
            ``qfallback`` otherwise.
        inputs: Sequence of tensors to concatenate.
        dim: Dimension along which to concatenate (defaults to 0).

    Returns:
        An ``ActivationQBytesTensor`` sharing the first input's qtype and
        scale in the fast path, otherwise whatever ``qfallback`` returns.
    """
    # The fast path only handles the pairwise case.
    if len(inputs) != 2:
        return qfallback(op, inputs, dim)

    first, second = inputs

    both_quantized = isinstance(first, ActivationQBytesTensor) and isinstance(second, ActivationQBytesTensor)
    if not both_quantized:
        return qfallback(op, inputs, dim)

    # The underlying data can only be concatenated directly when both
    # operands share the same scale and the same quantization type.
    same_quantization = torch.equal(first._scale, second._scale) and first.qtype == second.qtype
    if not same_quantization:
        return qfallback(op, inputs, dim)

    # Cat is not supported for float8
    if first.qtype.is_floating_point or second.qtype.is_floating_point:
        return qfallback(op, inputs, dim)

    joined = op([first._data, second._data], dim)
    return ActivationQBytesTensor(first.qtype, joined.size(), joined.stride(), joined, first._scale)