def _load_model()

in d2go/modeling/quantization.py


    def _load_model(self, checkpoint):
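        """
        Load checkpoint into self.model, handling the cases where the model is
        QAT-prepared but the checkpoint is a plain (non-QAT) one, and vice versa.
        """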
        model_is_qat = self._is_q_state_dict(self.model.state_dict())
        checkpoint_is_qat = self._is_q_state_dict(checkpoint["model"])

        if model_is_qat and not checkpoint_is_qat:
            logger.info("Loading QAT model with non-QAT checkpoint, ignore observers!")
            mapping = getattr(self.model, "_non_qat_to_qat_state_dict_map", {})

            # map the key from non-QAT model to QAT model if possible
            checkpoint_state_dict = {
                mapping.get(k, k): v for k, v in checkpoint["model"].items()
            }
            checkpoint["model"] = checkpoint_state_dict
            incompatible = super()._load_model(checkpoint)

            # suppress the missing observer keys warning
            # NOTE: for some reason incompatible.missing_keys can have duplicated keys,
            # here we replace the entire list rather than calling .remove()
            missing_non_qat_keys = [
                k for k in incompatible.missing_keys if not _is_observer_key(k)
            ]
            incompatible.missing_keys[:] = missing_non_qat_keys

            return incompatible

        elif not model_is_qat and checkpoint_is_qat:
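            # loading a QAT checkpoint into a non-QAT (float) model is not supported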
            raise NotImplementedError()
        elif model_is_qat and checkpoint_is_qat:
            # TODO: maybe suppress shape mismatch
            # For models trained with QAT and per-channel quant, the initial size of the
            # buffers in fake_quant and observer modules does not reflect the size in
            # state_dict, which is updated only when convert is called.
            return super()._load_model(checkpoint)
        else:
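            # neither the model nor the checkpoint is QAT; load as usual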
            return super()._load_model(checkpoint)
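
The method relies on two small helpers defined alongside it in the same file. As a minimal sketch (the exact substrings checked here are an assumption for illustration, not verbatim d2go code), they could look like:

    def _is_observer_key(state_dict_key):
        # keys created by fake-quant / observer submodules,
        # e.g. "...activation_post_process.scale" (substrings are illustrative)
        observer_keys = ["activation_post_process", "weight_fake_quant"]
        return any(x in state_dict_key for x in observer_keys)

    def _is_q_state_dict(state_dict):
        # a state dict counts as "QAT" if any of its keys belongs to an observer;
        # the checkpointer's _is_q_state_dict method is assumed to delegate to this check
        return any(_is_observer_key(k) for k in state_dict)

With those checks in place, the first branch remaps checkpoint keys through _non_qat_to_qat_state_dict_map (presumably built when the model is prepared for QAT) so that a float checkpoint can still initialize a QAT-prepared model, and any missing keys that come only from observers are filtered out of the incompatible-keys warning.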