def forward()

in optimum/quanto/tensor/qbits.py [0:0]


    def forward(ctx, t):
        """Dequantize ``t`` from its stored data, scale and shift.

        Steps, in order:
          1. Take the raw data, unpacking it first when it is a ``PackedTensor``.
          2. If the shift is an integer tensor, subtract it from the data in
             int8 *before* scaling.
          3. Multiply by ``t._scale``. For a floating-point qtype the data is
             first cast to the scale's dtype.
          4. If the shift is a floating-point tensor, subtract it in place
             *after* scaling.
          5. If ``t.axis`` is set, pass the result through ``ungroup`` with
             ``orig_shape=t.shape``; otherwise return it as is.

        Args:
            ctx: Unused here (presumably the autograd context object — confirm
                against the enclosing class).
            t: The quantized tensor. Its ``_data``, ``_shift``, ``_scale``,
                ``qtype``, ``axis`` and ``shape`` attributes are read.

        Returns:
            The dequantized tensor.
        """
        raw = t._data
        values = raw.unpack() if isinstance(raw, PackedTensor) else raw

        offset = t._shift
        shift_is_integer = not offset.dtype.is_floating_point
        if shift_is_integer:
            # Integer shift lives in the quantized domain: remove it before scaling.
            # NOTE(review): assumes data/shift values fit in int8 — confirm for the qtypes used here.
            values = values.to(torch.int8) - offset.to(torch.int8)

        if t.qtype.is_floating_point:
            # Make the upcast to the scale dtype explicit rather than relying on promotion.
            values = values.to(t._scale.dtype)
        dequantized = t._scale * values

        if not shift_is_integer:
            # Floating-point shift is already scaled: remove it after scaling.
            # Kept in place so the result keeps the dtype of the scaled product.
            dequantized -= offset

        if t.axis is None:
            return dequantized
        # Grouped tensors need their original shape restored.
        return ungroup(dequantized, axis=t.axis, orig_shape=t.shape)