in optimum/quanto/tensor/weights/qbytes.py [0:0]
@classmethod
def __torch_dispatch__(cls, op, types, args, kwargs=None):
    # Do not use op directly, but rather its overload packet
    op = op.overloadpacket
    if op is torch.ops.aten.detach:
        t = args[0]
        # Detach is required when copying and deserializing
        inner_tensor_names, meta = t.__tensor_flatten__()
        # Detach inner tensors
        detached_tensors = {}
        for inner_name in inner_tensor_names:
            detached_tensors[inner_name] = op(getattr(t, inner_name))
        return cls.__tensor_unflatten__(detached_tensors, meta, t.size(), t.stride())
    elif op in [torch.ops.aten._to_copy, torch.ops.aten.to]:
        t = args[0]
        # Pop dtype and device so that the remaining kwargs can be forwarded as-is
        dtype = kwargs.pop("dtype", t.dtype)
        device = kwargs.pop("device", t.device)
        if dtype != t.dtype:
            raise ValueError("The dtype of a weights Tensor cannot be changed")
        if type(t) is not WeightQBytesTensor and t.device.type != device.type:
            # Before moving to another device type, convert back to a WeightQBytesTensor
            t = t.weight_qbytes_tensor()
        out_data = op(t._data, device=device, **kwargs)
        out_scale = op(t._scale, device=device, **kwargs)
        return WeightQBytesTensor.create(
            t.qtype,
            t.axis,
            t.size(),
            t.stride(),
            out_data,
            out_scale,
            activation_qtype=t.activation_qtype,
            requires_grad=t.requires_grad,
        )
    elif op is torch.ops.aten.t and cls is WeightQBytesTensor:
        t = args[0]
        out_data = op(t._data)
        out_scale = t._scale
        out_axis = t.axis
        # Manually reverse size and stride because we cannot trust the out_data shape
        dim0, dim1 = t.size()
        out_size = torch.Size([dim1, dim0])
        out_stride = t.stride()[::-1]
        if t.axis is not None:
            # We also need to transpose the scale and flip the quantization axis
            out_scale = op(out_scale)
            out_axis = 0 if out_axis == -1 else -1
        return WeightQBytesTensor(t.qtype, out_axis, out_size, out_stride, out_data, out_scale, t.activation_qtype)
    # No dispatch available: qfallback
    kwargs = kwargs or {}
    return qfallback(op, *args, **kwargs)
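
A minimal usage sketch (not part of qbytes.py) of the operations this dispatch handles. It builds a WeightQBytesTensor directly with the positional argument order visible in the snippet (qtype, axis, size, stride, data, scale, activation_qtype); in practice such tensors are produced by the library's quantization helpers, so treat the construction below as an illustrative assumption.

import torch
from optimum.quanto import qint8
from optimum.quanto.tensor.weights.qbytes import WeightQBytesTensor

# Hypothetical per-tensor (axis=None) int8 weights with a scalar float scale
data = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
scale = torch.tensor(0.01)
qw = WeightQBytesTensor(qint8, None, data.size(), data.stride(), data, scale, None)

qw.detach()   # aten.detach branch: detaches _data and _scale, then re-wraps them
qw.t()        # aten.t branch: transposes _data and reverses size/stride
qw.to("cpu")  # aten._to_copy / aten.to branch: moves _data and _scale together
# qw.to(torch.float16) would hit the dtype check and raise ValueError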