def __torch_dispatch__()

in optimum/quanto/tensor/weights/qbits.py [0:0]


    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        # Do not use op directly, but rather its overloadpacket
        op = op.overloadpacket
        if op is torch.ops.aten.detach:
            t = args[0]
            # Detach is required when copying and deserializing
            inner_tensor_names, meta = t.__tensor_flatten__()
            # Detach inner tensors
            detached_tensors = {}
            for inner_name in inner_tensor_names:
                detached_tensors[inner_name] = op(getattr(t, inner_name))
            return cls.__tensor_unflatten__(detached_tensors, meta, t.size(), t.stride())
        elif op in [torch.ops.aten._to_copy, torch.ops.aten.to]:
            t = args[0]
            dtype = kwargs.pop("dtype", t.dtype)
            device = kwargs.pop("device", t.device)
            if dtype is not None and dtype != t.dtype:
                raise ValueError("The dtype of a WeightQBitsTensor cannot be changed")
            if type(t) is not WeightQBitsTensor and t.device.type != device.type:
                # Before moving to another device type, convert back to a WeightQBitsTensor
                t = t.weight_qbits_tensor()
            # Only the scale is cast to the requested dtype; data and shift keep their own dtypes
            scale = op(t._scale, dtype=dtype, device=device, **kwargs)
            data = op(t._data, device=device, **kwargs)
            shift = op(t._shift, device=device, **kwargs)
            return WeightQBitsTensor.create(t._qtype, t._axis, t._group_size, t.size(), t.stride(), data, scale, shift)
        # No dispatch available: qfallback
        kwargs = kwargs or {}
        return qfallback(op, *args, **kwargs)
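
For context, a minimal, self-contained sketch of the same pattern follows. It is not code from the repository: the class ToyQTensor, its _data/_scale fields, and the dequantize-and-redispatch fallback (standing in for quanto's qfallback) are illustrative assumptions; only torch is required.

import torch


class ToyQTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, scale):
        # The outer wrapper only carries shape/dtype/device metadata.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=scale.dtype, device=data.device, requires_grad=False
        )

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

    def __tensor_flatten__(self):
        # Inner tensor attribute names, plus optional non-tensor metadata.
        return ["_data", "_scale"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        return ToyQTensor(inner_tensors["_data"], inner_tensors["_scale"])

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        op = op.overloadpacket
        if op is torch.ops.aten.detach:
            # Same pattern as above: detach each inner tensor, then rebuild.
            t = args[0]
            inner_names, meta = t.__tensor_flatten__()
            detached = {name: op(getattr(t, name)) for name in inner_names}
            return cls.__tensor_unflatten__(detached, meta, t.size(), t.stride())
        # Fallback: dequantize every ToyQTensor argument and re-dispatch the op,
        # mimicking what qfallback does for unhandled ops.
        kwargs = kwargs or {}
        plain_args = [
            a._data.to(a._scale.dtype) * a._scale if isinstance(a, ToyQTensor) else a
            for a in args
        ]
        return op(*plain_args, **kwargs)


qt = ToyQTensor(torch.randint(-8, 8, (4, 4), dtype=torch.int8), torch.tensor(0.1))
print(qt.detach()._data.dtype)          # torch.int8: inner tensors are preserved
print(torch.matmul(qt, torch.eye(4)))   # unhandled op: dequantize, then dispatch

Rebuilding detach through __tensor_flatten__/__tensor_unflatten__ keeps the subclass intact across copying and deserialization, both of which detach tensors internally, which is exactly the reason given in the source comment.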