in dpr_scale/task/dpr_task.py [0:0]
def setup(self, stage: str):
    """Build the query/context encoders and optionally load a pretrained checkpoint.

    Args:
        stage: Stage name passed in by the caller. When it is ``"test"`` and a
            previous call already completed (``self.setup_done`` is truthy),
            the method returns immediately without rebuilding anything.

    Side effects:
        Sets ``self.call_configure_sharded_model_hook`` to ``False``, assigns
        ``self.query_encoder`` and ``self.context_encoder`` (the same object
        when ``self.shared_model`` is truthy), optionally loads
        ``self.pretrained_checkpoint_path`` into this module via
        ``load_state_dict``, and finally sets ``self.setup_done = True``.
    """
    # Skip building the model during test; otherwise the state dict would be
    # re-initialized and any loaded/trained weights lost.
    if stage == "test" and self.setup_done:
        return
    # Reset call_configure_sharded_model_hook so that the model can be configured.
    self.call_configure_sharded_model_hook = False
    self.query_encoder = hydra.utils.instantiate(
        self.model_conf,
    )
    if self.shared_model:
        # Both towers point at the same module instance.
        self.context_encoder = self.query_encoder
    else:
        self.context_encoder = hydra.utils.instantiate(
            self.model_conf,
        )
    if self.pretrained_checkpoint_path:
        # Use a context manager so the checkpoint file handle is closed
        # deterministically instead of being left for the garbage collector.
        with PathManager.open(self.pretrained_checkpoint_path, "rb") as f:
            checkpoint_dict = torch.load(
                f,
                # Restore every storage to the "cpu" location, ignoring the
                # location it was saved from.
                map_location=lambda storage, _location: default_restore_location(
                    storage, "cpu"
                ),
            )
        self.load_state_dict(checkpoint_dict["state_dict"])
        print(f"Loaded state dict from {self.pretrained_checkpoint_path}")
    self.setup_done = True