def __torch_dispatch__()

in optimum/quanto/tensor/weights/qbytes.py [0:0]


    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        """Dispatch aten operations applied to quantized weight tensors.

        Only the small set of ops needed for copying, (de)serialization and
        device movement is handled natively (``detach``, ``_to_copy``/``to``,
        ``t``); every other op is delegated to ``qfallback``, which operates on
        the dequantized tensor.

        Args:
            op: the aten operator (an overload; reduced to its overload packet).
            types: the tensor subclass types involved in the call (unused here).
            args: positional arguments; ``args[0]`` is the quantized tensor.
            kwargs: keyword arguments, possibly ``None``.

        Raises:
            ValueError: if a ``to``/``_to_copy`` call attempts a dtype change.
        """
        # torch may pass kwargs as None: normalize once up-front so that every
        # branch (not just the fallback) can safely pop from / unpack it.
        kwargs = {} if kwargs is None else kwargs
        # Do not use directly op, but rather its overload packet, so the
        # identity comparisons below match every overload of each op.
        op = op.overloadpacket
        if op is torch.ops.aten.detach:
            t = args[0]
            # Detach is required when copying and deserializing
            inner_tensor_names, meta = t.__tensor_flatten__()
            # Detach each inner tensor, then rebuild a tensor of the same class
            detached_tensors = {name: op(getattr(t, name)) for name in inner_tensor_names}
            return cls.__tensor_unflatten__(detached_tensors, meta, t.size(), t.stride())
        elif op in [torch.ops.aten._to_copy, torch.ops.aten.to]:
            t = args[0]
            dtype = kwargs.pop("dtype", t.dtype)
            device = kwargs.pop("device", t.device)
            if dtype != t.dtype:
                raise ValueError("The dtype of a weights Tensor cannot be changed")
            if type(t) is not WeightQBytesTensor and t.device.type != device.type:
                # Before moving to another device type, convert back to a WeightQBytesTensor
                t = t.weight_qbytes_tensor()
            out_data = op(t._data, device=device, **kwargs)
            out_scale = op(t._scale, device=device, **kwargs)
            return WeightQBytesTensor.create(
                t.qtype,
                t.axis,
                t.size(),
                t.stride(),
                out_data,
                out_scale,
                activation_qtype=t.activation_qtype,
                requires_grad=t.requires_grad,
            )
        elif op is torch.ops.aten.t and cls is WeightQBytesTensor:
            t = args[0]
            out_data = op(t._data)
            out_scale = t._scale
            out_axis = t.axis
            # Manually reverse size and stride because we cannot trust the out_data shape
            dim0, dim1 = t.size()
            out_size = torch.Size([dim1, dim0])
            out_stride = t.stride()[::-1]
            if t.axis is not None:
                # Per-axis quantization: the scale must be transposed as well,
                # and the quantization axis flips between -1 and 0.
                out_scale = op(out_scale)
                out_axis = 0 if out_axis == -1 else -1
            return WeightQBytesTensor(t.qtype, out_axis, out_size, out_stride, out_data, out_scale, t.activation_qtype)
        # No dispatch available: qfallback dequantizes before applying the op
        return qfallback(op, *args, **kwargs)