in d2go/modeling/quantization.py [0:0]
def _load_model(self, checkpoint):
    model_is_qat = self._is_q_state_dict(self.model.state_dict())
    checkpoint_is_qat = self._is_q_state_dict(checkpoint["model"])
    if model_is_qat and not checkpoint_is_qat:
        logger.info("Loading QAT model with non-QAT checkpoint, ignore observers!")
        mapping = getattr(self.model, "_non_qat_to_qat_state_dict_map", {})
        # map keys from the non-QAT model to the QAT model where possible
        checkpoint_state_dict = {
            mapping.get(k, k): v for k, v in checkpoint["model"].items()
        }
        checkpoint["model"] = checkpoint_state_dict
        incompatible = super()._load_model(checkpoint)
        # suppress the warning about missing observer keys
        # NOTE: incompatible.missing_keys can contain duplicated keys for some
        # reason; replace the entire list rather than calling .remove()
        missing_non_qat_keys = [
            k for k in incompatible.missing_keys if not _is_observer_key(k)
        ]
        incompatible.missing_keys[:] = missing_non_qat_keys
        return incompatible
    elif not model_is_qat and checkpoint_is_qat:
        raise NotImplementedError()
    elif model_is_qat and checkpoint_is_qat:
        # TODO: maybe suppress shape mismatch
        # For models trained with QAT and per-channel quant, the initial size of
        # the buffers in fake_quant and observer modules does not reflect the
        # size in the state_dict, which is updated only when convert is called.
        return super()._load_model(checkpoint)
    else:
        return super()._load_model(checkpoint)
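
The helpers referenced above (`_is_q_state_dict`, `_is_observer_key`) are defined elsewhere in the file. As a rough, standalone illustration of the key-remapping and observer-filtering pattern used in the QAT-model / non-QAT-checkpoint branch, here is a minimal sketch; the observer-key substrings ("activation_post_process", "weight_fake_quant"), the example key names, and the mapping dict are assumptions for illustration, not taken from d2go.

# Minimal sketch (not the d2go implementation) of observer-key detection and
# non-QAT -> QAT key remapping; all names below are hypothetical examples.
def _is_observer_key_sketch(state_dict_key):
    # assume QAT-only buffers live under module names containing these markers
    return any(
        marker in state_dict_key
        for marker in ("activation_post_process", "weight_fake_quant")
    )

non_qat_checkpoint = {"backbone.conv1.weight": "tensor(...)"}          # hypothetical
mapping = {"backbone.conv1.weight": "backbone.conv1.0.weight"}         # hypothetical
remapped = {mapping.get(k, k): v for k, v in non_qat_checkpoint.items()}

missing_keys = ["backbone.conv1.0.activation_post_process.scale", "head.fc.bias"]
missing_keys = [k for k in missing_keys if not _is_observer_key_sketch(k)]
# -> ["head.fc.bias"]: observer buffers are expected to be absent from a
#    non-QAT checkpoint, so they are dropped from the missing-keys warning,
#    mirroring the filtering done in _load_model above.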