def setup(self, stage: str)

in dpr_scale/task/dpr_task.py [0:0]


    def setup(self, stage: str):
        """Build the query/context encoder(s) and optionally load a pretrained checkpoint.

        Args:
            stage: Lightning stage name (e.g. "fit", "test"). When stage is
                "test" and setup already ran, the call is a no-op so the
                previously loaded state dict is not re-initialized.
        """
        # skip building model during test.
        # Otherwise, the state dict will be re-initialized
        if stage == "test" and self.setup_done:
            return
        # resetting call_configure_sharded_model_hook attribute so that we could configure model
        self.call_configure_sharded_model_hook = False

        # Instantiate the query encoder from the hydra model config.
        self.query_encoder = hydra.utils.instantiate(
            self.model_conf,
        )
        if self.shared_model:
            # Query and context towers share a single set of weights.
            self.context_encoder = self.query_encoder
        else:
            self.context_encoder = hydra.utils.instantiate(
                self.model_conf,
            )

        if self.pretrained_checkpoint_path:
            # Use a context manager so the checkpoint file handle is closed;
            # the original leaked the handle returned by PathManager.open.
            with PathManager.open(self.pretrained_checkpoint_path, "rb") as f:
                # Remap all storages to CPU regardless of where they were saved.
                checkpoint_dict = torch.load(
                    f,
                    map_location=lambda s, l: default_restore_location(s, "cpu"),
                )
            self.load_state_dict(checkpoint_dict["state_dict"])
            print(f"Loaded state dict from {self.pretrained_checkpoint_path}")

        self.setup_done = True