in optimum/quanto/tensor/packed.py [0:0]
@classmethod
def __torch_dispatch__(cls, op, types, args, kwargs=None):
    # Handle detach and dtype/device moves natively so the result stays packed;
    # any other op falls back to the unpacked (torch.uint8) representation below.
    if op.overloadpacket is torch.ops.aten.detach:
        t = args[0]
        data = op(t._data)
        return PackedTensor(data, t._bits, t.size(), t.stride())
    elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):
        t = args[0]
        dtype = kwargs.get("dtype", torch.uint8)
        if dtype != torch.uint8:
            raise ValueError(f"PackedTensor are torch.uint8 only and cannot be moved to {dtype}.")
        # Move the packed data (e.g. to another device), keeping the uint8 dtype
        data = op(t._data, **kwargs)
        return PackedTensor(data, t._bits, t.size(), t.stride())
    # Fallback: unpack every PackedTensor argument and dispatch the op on plain tensors
    args, kwargs = pytree.tree_map_only(PackedTensor, lambda x: x.unpack(), (args, kwargs or {}))
    return op(*args, **kwargs)
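
The same dispatch pattern can be sketched outside the library. The toy subclass below (ScaledTensor, unpack, _data and _scale are illustrative names, not part of optimum-quanto) is a minimal, self-contained example: detach is handled natively so the result stays wrapped, while every other op falls back to the unpacked plain tensor, mirroring the method above. It uses torch.Tensor._make_wrapper_subclass, the standard helper for wrapper tensor subclasses.

import torch
from torch.utils import _pytree as pytree


class ScaledTensor(torch.Tensor):
    """Toy wrapper subclass: uint8 payload plus a scale (hypothetical, for illustration only)."""

    @staticmethod
    def __new__(cls, data, scale, requires_grad=False):
        # The wrapper advertises the payload's shape/dtype; the actual data lives in _data.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.size(), dtype=torch.uint8, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, data, scale, requires_grad=False):
        self._data = data
        self._scale = scale

    def unpack(self):
        # Convert back to a plain float tensor
        return self._data.to(torch.float32) * self._scale

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        # detach keeps the subclass; everything else is dispatched on the unpacked tensor
        if op.overloadpacket is torch.ops.aten.detach:
            t = args[0]
            return ScaledTensor(op(t._data), t._scale)
        args, kwargs = pytree.tree_map_only(ScaledTensor, lambda x: x.unpack(), (args, kwargs or {}))
        return op(*args, **kwargs)


t = ScaledTensor(torch.randint(0, 255, (4,), dtype=torch.uint8), scale=0.1)
print(type(t.detach()))  # ScaledTensor: handled natively, stays wrapped
print(type(t + 1.0))     # torch.Tensor: unpacked before dispatch

The design trade-off visible in both versions is that ops handled natively (detach, and in PackedTensor also to/_to_copy) preserve the compact storage, while the generic fallback trades the subclass for correctness on arbitrary operations by unpacking first.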