in optimum/quanto/tensor/weights/qbits.py [0:0]
@classmethod
def __torch_dispatch__(cls, op, types, args, kwargs=None):
    # Do not match on op directly, but on its overloadpacket, so that all
    # overloads of the same operator are handled by a single branch
    op = op.overloadpacket
    if op is torch.ops.aten.detach:
        t = args[0]
        # Detach is required when copying and deserializing
        inner_tensor_names, meta = t.__tensor_flatten__()
        # Detach each inner tensor, then rebuild the subclass instance from the
        # detached tensors (see the flatten/unflatten sketch after this function)
        detached_tensors = {}
        for inner_name in inner_tensor_names:
            detached_tensors[inner_name] = op(getattr(t, inner_name))
        return cls.__tensor_unflatten__(detached_tensors, meta, t.size(), t.stride())
    elif op in [torch.ops.aten._to_copy, torch.ops.aten.to]:
        t = args[0]
        dtype = kwargs.pop("dtype", t.dtype)
        device = kwargs.pop("device", t.device)
        if dtype is not None and dtype != t.dtype:
            raise ValueError("The dtype of a WeightQBitsTensor cannot be changed")
        if type(t) is not WeightQBitsTensor and t.device.type != device.type:
            # Optimized subclasses are tied to a device type: convert back to a
            # generic WeightQBitsTensor before moving to another device type
            t = t.weight_qbits_tensor()
        scale = op(t._scale, dtype=dtype, device=device, **kwargs)
        data = op(t._data, device=device, **kwargs)
        shift = op(t._shift, device=device, **kwargs)
        return WeightQBitsTensor.create(
            t._qtype, t._axis, t._group_size, t.size(), t.stride(), data, scale, shift
        )
    # No dedicated dispatch for this op: fall back to qfallback, which
    # dequantizes the quantized arguments before executing the op
    kwargs = kwargs or {}
    return qfallback(op, *args, **kwargs)
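
For context, here is a self-contained toy subclass showing the flatten/unflatten protocol the detach branch relies on. ScaledTensor is illustrative only, not a quanto class; the real WeightQBitsTensor hooks additionally carry the metadata (qtype, axis, group size) that create needs, as seen above.

# Sketch only: ScaledTensor is a toy wrapper subclass, not part of quanto.
import torch

class ScaledTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, scale):
        # Wrapper subclass: the outer tensor only carries shape/dtype/device
        return torch.Tensor._make_wrapper_subclass(cls, data.size(), dtype=scale.dtype, device=data.device)

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

    def __tensor_flatten__(self):
        # Inner tensor names plus the non-tensor metadata needed to rebuild
        return ["_data", "_scale"], {}

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
        # Reassemble an instance from the (possibly transformed) inner tensors
        return cls(inner_tensors["_data"], inner_tensors["_scale"])

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        if op.overloadpacket is torch.ops.aten.detach:
            t = args[0]
            names, meta = t.__tensor_flatten__()
            detached = {name: getattr(t, name).detach() for name in names}
            return cls.__tensor_unflatten__(detached, meta, t.size(), t.stride())
        raise NotImplementedError(f"{op} is not handled in this sketch")

st = ScaledTensor(torch.randint(-8, 8, (4, 4), dtype=torch.int8), torch.tensor(0.05))
print(type(st.detach()).__name__)  # ScaledTensor: the subclass survives detach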
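
And a hedged usage sketch of how the branches above are hit in practice. The quantize_weight helper, its signature, and the import paths are assumptions based on this file's location; check them against your optimum-quanto version.

# Usage sketch: quantize_weight and its import path are assumptions,
# not verified against a specific optimum-quanto release.
import torch
from optimum.quanto import qint4
from optimum.quanto.tensor.weights import quantize_weight  # hypothetical path

w = torch.randn(256, 256)
qw = quantize_weight(w, qtype=qint4, axis=0, group_size=64)

qw2 = qw.detach()  # aten.detach branch: inner tensors detached, subclass kept
if torch.cuda.is_available():
    qw_gpu = qw.to("cuda")  # aten.to / aten._to_copy branch moves inner tensors
try:
    qw.to(torch.float16)  # dtype changes are expected to be rejected
except ValueError as e:
    print(e)
y = qw + 1  # no dedicated branch: qfallback dequantizes, returns a plain tensor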