in optimum/quanto/tensor/weights/marlin/int4/packed.py [0:0]
def __torch_dispatch__(cls, op, types, args, kwargs=None):
    """Intercept aten ops on MarlinInt4PackedTensor.

    Three cases are handled specially:
    - ``aten.detach``: detach the packed storage and re-wrap it, preserving
      the logical size/stride of the quantized tensor.
    - ``aten._to_copy`` / ``aten.to``: only ``torch.uint8`` dtype is allowed
      (the packed representation is uint8-only); a CUDA-to-CUDA move keeps
      the data packed, while any other target device falls back to
      ``unpack()``.  NOTE(review): the non-CUDA path returns the unpacked
      tensor without an explicit device transfer — presumably the caller or
      a subsequent dispatch completes the move; verify against ``unpack()``.
    Every other op is dispatched on the unpacked (plain uint8) tensors.

    Raises:
        ValueError: if a ``to``/``_to_copy`` requests a dtype other than
            ``torch.uint8``.
    """
    packet = op.overloadpacket

    if packet is torch.ops.aten.detach:
        source = args[0]
        detached_data = op(source._data)
        return MarlinInt4PackedTensor(detached_data, source.size(), source.stride())

    if packet in (torch.ops.aten._to_copy, torch.ops.aten.to):
        source = args[0]
        target_dtype = kwargs.get("dtype", torch.uint8)
        if target_dtype != torch.uint8:
            raise ValueError(f"MarlinInt4PackedTensor are torch.uint8 only and cannot be moved to {target_dtype}.")
        target_device = kwargs.get("device", source.device)
        if target_device.type != "cuda":
            # Packed layout is CUDA-only: leaving CUDA means unpacking.
            return source.unpack()
        # Move the packed storage itself, forcing its own dtype so the
        # uint8-packed representation survives the copy.
        forward_kwargs = copy(kwargs)
        forward_kwargs["dtype"] = source._data.dtype
        moved_data = op(source._data, **forward_kwargs)
        return MarlinInt4PackedTensor(moved_data, source.size(), source.stride())

    # Generic fallback: replace every packed tensor in args/kwargs by its
    # unpacked uint8 equivalent and dispatch the op normally.
    plain_args, plain_kwargs = pytree.tree_map_only(
        MarlinInt4PackedTensor, lambda x: x.unpack(), (args, kwargs or {})
    )
    return op(*plain_args, **plain_kwargs)