in optimum/quanto/tensor/weights/marlin/fp8/packed.py [0:0]
def __torch_dispatch__(cls, op, types, args, kwargs=None):
    """Dispatch an ATen op, staying packed only for detach and CUDA copies.

    ``detach`` and copies targeting a CUDA device are applied to the
    underlying ``_data`` buffer and the result is re-wrapped in ``cls`` with
    the original size and stride. A copy targeting any other device returns
    the unpacked tensor moved to that device. Every other op is executed on
    unpacked arguments.

    Args:
        cls: Packed tensor class; invoked as ``cls(data, size, stride)`` to
            re-wrap results.
        op: Op overload being dispatched, identified through
            ``op.overloadpacket``.
        types: Unused here.
        args: Positional arguments of the op. For detach and copy ops,
            ``args[0]`` is the packed tensor.
        kwargs: Keyword arguments of the op; ``None`` is treated as empty.

    Returns:
        A ``cls`` instance for detach and CUDA copies, otherwise whatever the
        op (or ``unpack().to(device)``) returns.

    Raises:
        ValueError: If a copy requests a dtype other than ``torch.int32``.
    """
    # The signature allows kwargs=None, so normalise once for every branch.
    kwargs = {} if kwargs is None else kwargs
    packet = op.overloadpacket

    if packet is torch.ops.aten.detach:
        t = args[0]
        data = op(t._data)
        return cls(data, t.size(), t.stride())

    if packet in (torch.ops.aten._to_copy, torch.ops.aten.to):
        t = args[0]
        # A missing or explicit-None dtype means "leave the dtype unchanged".
        dtype = kwargs.get("dtype")
        if dtype is not None and dtype != torch.int32:
            raise ValueError(f"MarlinF8PackedTensor are torch.int32 only and cannot be moved to {dtype}.")
        # A missing or explicit-None device means "stay on the current device".
        device = kwargs.get("device")
        if device is None:
            device = t.device
        # Accept strings/indices as well as torch.device instances.
        device = torch.device(device)
        if device.type == "cuda":
            # Move the raw buffer, forcing its own dtype so it is never cast.
            data_kwargs = {**kwargs, "dtype": t._data.dtype}
            data = op(t._data, **data_kwargs)
            return cls(data, t.size(), t.stride())
        # Any non-CUDA target gets the unpacked tensor instead.
        return t.unpack().to(device)

    # Convert packed tensors back to regular tensors before any other op.
    args, kwargs = pytree.tree_map_only(cls, lambda x: x.unpack(), (args, kwargs))
    return op(*args, **kwargs)