in optimum/quanto/tensor/qbits.py [0:0]
def forward(ctx, t):
    if isinstance(t._data, PackedTensor):
        # Unpack bit-packed storage into one integer code per element
        data = t._data.unpack()
    else:
        data = t._data
    shift = t._shift
    if not shift.dtype.is_floating_point:
        # Remove shift before multiplying by the scale
        data = data.to(torch.int8) - shift.to(torch.int8)
    if t.qtype.is_floating_point:
        # Upcast explicitly to the scale dtype
        dqt = t._scale * data.to(t._scale.dtype)
    else:
        dqt = t._scale * data
    if shift.dtype.is_floating_point:
        # Remove scaled shift
        dqt -= shift
    if t.axis is None:
        return dqt
    # Restore the original shape (if needed)
    return ungroup(dqt, axis=t.axis, orig_shape=t.shape)
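
The unpack() branch above restores one integer code per element from bit-packed storage. A minimal sketch of the idea for 4-bit data, using hypothetical pack_int4/unpack_int4 helpers (this illustrates the concept only, not PackedTensor's actual layout):

import torch

# Hypothetical helpers: two 4-bit codes are stored per uint8 byte and
# recovered with shifts and masks.
def pack_int4(values: torch.Tensor) -> torch.Tensor:
    # values: uint8 tensor with entries in [0, 15], even number of elements
    v = values.reshape(-1, 2)
    return (v[:, 0] | (v[:, 1] << 4)).to(torch.uint8)

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack([low, high], dim=-1).reshape(-1)

q = torch.tensor([3, 12, 7, 0], dtype=torch.uint8)
assert torch.equal(unpack_int4(pack_int4(q)), q)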
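
The two shift branches are numerically equivalent: an integer shift (zero point) is removed from the codes before scaling, while a floating-point shift is already expressed in the dequantized units and is removed after scaling. A minimal sketch with plain tensors (scale, data, zero_point and float_shift are illustrative names, not quanto attributes):

import torch

scale = torch.tensor(0.05)
data = torch.tensor([0, 7, 15], dtype=torch.uint8)   # unsigned 4-bit codes
zero_point = torch.tensor(8, dtype=torch.uint8)      # integer shift

# Integer shift: remove it before multiplying by the scale
dq_int_shift = scale * (data.to(torch.int8) - zero_point.to(torch.int8))

# Float shift: the shift is already scaled, so it is removed after scaling
float_shift = scale * zero_point.to(scale.dtype)
dq_float_shift = scale * data - float_shift

assert torch.allclose(dq_int_shift, dq_float_shift)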
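
The final ungroup() call undoes the grouping applied at quantization time, where the tensor is split into fixed-size groups along an axis so that each group gets its own scale. A hypothetical axis-0 sketch of the reshape round trip (group_rows/ungroup_rows are illustrative helpers, not quanto's group/ungroup, which also handle axis=-1 and their own layout conventions):

import torch

def group_rows(base: torch.Tensor, group_size: int) -> torch.Tensor:
    # Split the tensor into contiguous groups of group_size elements
    return base.reshape(-1, group_size)

def ungroup_rows(grouped: torch.Tensor, orig_shape: torch.Size) -> torch.Tensor:
    # Restore the original tensor shape after per-group processing
    return grouped.reshape(orig_shape)

w = torch.randn(4, 8)
g = group_rows(w, group_size=4)   # shape (8, 4): one scale per group
assert torch.equal(ungroup_rows(g, w.shape), w)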