in diffq/base.py [0:0]
def restore_quantized_state(self, state) -> None:
    """
    Restore the state of the model from the quantized state.

    Args:
        state: mapping with the keys ``"float16"``, ``"others"`` and
            ``"quantized"``, plus an optional ``"meta"`` dict.
            ``"float16"`` entries are converted with ``.to(p)`` and copied
            into ``self._float16``. ``"others"`` entries are copied as-is
            into ``self._others``. ``"quantized"`` holds one entry per
            quantized parameter that is stored (see below), in
            ``self._qparams`` order. ``meta["packed"]`` says whether those
            entries must first be bit-unpacked, and ``meta["torch_pack"]``
            selects the unpacking function.

    Raises:
        ValueError: if the number of entries in ``state["quantized"]``
            differs from the number of quantized parameters that are
            stored. The check runs before any parameter is modified.
    """
    # Only the first appearance of an nn.Parameter is stored in the quantized
    # state; qparams with ``other`` set are skipped (original behavior).
    targets = [qparam for qparam in self._qparams if qparam.other is None]
    quantized_states = list(state["quantized"])
    # Validate up front so a mismatched checkpoint cannot leave the model
    # partially restored. An explicit raise (instead of `assert`) keeps the
    # check active under `python -O`.
    if len(quantized_states) != len(targets):
        raise ValueError(
            f"Quantized state holds {len(quantized_states)} entries but "
            f"{len(targets)} quantized parameters were expected."
        )

    for p, q in zip(self._float16, state["float16"]):
        p.data[:] = q.to(p)
    for p, q in zip(self._others, state["others"]):
        p.data[:] = q

    meta = state.get("meta", {})
    packed = meta.get("packed", False)
    torch_pack = meta.get("torch_pack", False)
    if torch_pack:
        unpack_fn = torch_pack_mod.unpack
    else:
        unpack_fn = bitpack.unpack

    # zip pairs each stored entry with its target in order, avoiding the
    # O(n) cost of list.pop(0) on every iteration.
    for qparam, quantized in zip(targets, quantized_states):
        if packed:
            quantized = self._bit_unpack_param(qparam, quantized, unpack_fn)
        qparam.param.data[:] = self._unquantize_param(qparam, quantized)
    self._fix_rnns()