in optimum/quanto/tensor/weights/tinygemm/packed.py [0:0]
@classmethod
def __torch_dispatch__(cls, op, types, args, kwargs=None):
    # Convert back to tensor before calling any operation except detach and move
    if op.overloadpacket is torch.ops.aten.detach:
        t = args[0]
        data = op(t._data)
        return TinyGemmPackedTensor(data, t.size(), t.stride())
    elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):
        t = args[0]
        dtype = kwargs.get("dtype", torch.uint8)
        if dtype != torch.uint8:
            raise ValueError(f"TinyGemmPackedTensor are torch.uint8 only and cannot be moved to {dtype}.")
        data_kwargs = copy(kwargs)
        data_kwargs["dtype"] = t._data.dtype
        if kwargs.get("device", t.device).type != t.device.type:
            # Packing is device specific, so we need to unpack before moving
            unpacked = t.unpack()
            unpacked = op(unpacked, **data_kwargs)
            return TinyGemmPackedTensor.pack(unpacked)
        # If we stay on the same device type, just copy/move packed data
        data = op(t._data, **data_kwargs)
        return TinyGemmPackedTensor(data, t.size(), t.stride())
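    # For every other op, fall through: unpack packed operands and dispatch on plain uint8 tensors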
    args, kwargs = pytree.tree_map_only(TinyGemmPackedTensor, lambda x: x.unpack(), (args, kwargs or {}))
    return op(*args, **kwargs)
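For context, a short usage sketch of the behaviour this dispatch implements, seen from the caller's side. It is illustrative only: the variable `packed`, the target device and the asserts are assumptions, and the packed tensor is assumed to have been built elsewhere with `TinyGemmPackedTensor.pack()` from a `torch.uint8` tensor that satisfies the tinygemm layout constraints.

import torch

# Assumption: `packed` is a TinyGemmPackedTensor created elsewhere, e.g.
# packed = TinyGemmPackedTensor.pack(unpacked_uint8_weight)

# aten.detach keeps the packed layout: only the inner packed data is detached.
detached = packed.detach()
assert isinstance(detached, TinyGemmPackedTensor)

# Moving to a different device type triggers unpack -> move -> repack, since the
# packing layout is device specific; moving within the same device type just
# copies the packed data.
moved = packed.to("cuda")  # illustrative target device

# Requesting any dtype other than torch.uint8 is rejected.
try:
    packed.to(torch.float16)
except ValueError:
    pass

# Any other aten op falls through: packed operands are unpacked to plain
# torch.uint8 tensors before the op runs, so the result is a regular tensor.
assert torch.equal(packed, packed.unpack())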