in ubteacher/checkpoint/detection_checkpoint.py [0:0]
def _load_student_model(self, checkpoint: Any) -> _IncompatibleKeys:  # pyre-ignore
    """Load the ``"model"`` entry of ``checkpoint`` into the student network.

    Args:
        checkpoint: mapping holding the serialized weights under the
            ``"model"`` key. That entry is popped, so ``checkpoint`` is
            modified in place.

    Returns:
        An ``_IncompatibleKeys`` carrying the ``missing_keys`` and
        ``unexpected_keys`` reported by ``load_state_dict`` plus
        ``incorrect_shapes``, a list of ``(key, checkpoint_shape,
        model_shape)`` tuples for the entries that were dropped before
        loading because their shapes differed.
    """
    student = self.model.modelStudent

    loaded_state = checkpoint.pop("model")
    self._convert_ndarray_to_tensor(loaded_state)

    # Weights serialized from a DataParallel / DistributedDataParallel
    # wrapper carry a "module." prefix; remove it so that the keys can be
    # matched against the student's own state dict.
    _strip_prefix_if_present(loaded_state, "module.")

    # Work around https://github.com/pytorch/pytorch/issues/24139:
    # entries whose shape differs from the student's tensor of the same
    # name are recorded and removed before calling load_state_dict.
    target_state = student.state_dict()
    shape_mismatches = []
    # Iterate over a snapshot of the keys because entries are removed below.
    for name in tuple(loaded_state.keys()):
        if name not in target_state:
            continue
        expected_shape = tuple(target_state[name].shape)
        found_shape = tuple(loaded_state[name].shape)
        if expected_shape == found_shape:
            continue
        shape_mismatches.append((name, found_shape, expected_shape))
        loaded_state.pop(name)

    # pyre-ignore
    load_result = student.load_state_dict(loaded_state, strict=False)
    return _IncompatibleKeys(
        missing_keys=load_result.missing_keys,
        unexpected_keys=load_result.unexpected_keys,
        incorrect_shapes=shape_mismatches,
    )