def load_model()

in trainers/catex.py


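    # Note: this method assumes module-level imports such as `import os.path as osp`
    # and a `load_checkpoint` helper (as in Dassl-based trainers); these are not
    # shown in this snippet.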
    def load_model(self, directory, epoch=None):
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            model_dict = self._models[name].state_dict()
            if "token_prefix" in state_dict:
                del state_dict["token_prefix"]
            if "token_suffix" in state_dict:
                del state_dict["token_suffix"]
            assert all(k in model_dict for k in state_dict)

            print("Loading weights to {} {} " 'from "{}" (epoch = {})'.format(name, list(state_dict.keys()), model_path, epoch))
            # set strict=False
            self._models[name].load_state_dict(state_dict, strict=False)

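        # With ensemble-initialized contexts, recompute the cached text-feature
        # ensemble from the freshly loaded prompt weights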
        if self.cfg.TRAINER.CATEX.CTX_INIT:
            assert self.cfg.TRAINER.CATEX.CTX_INIT == 'ensemble_learned'
            text_feature = self.model.get_text_features()
            self.model.text_feature_ensemble = self.model.prompt_ensemble(text_feature)
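For context, a minimal sketch of how load_model() is typically driven from a Dassl-style entry script; build_trainer, args.model_dir, and args.load_epoch are illustrative assumptions, not part of catex.py:

    # Hypothetical driver code (names are assumptions, not from this repo).
    # Weights are expected under <model_dir>/<model_name>/model.pth.tar-<epoch>,
    # or <model_dir>/<model_name>/model-best.pth.tar when no epoch is given.
    trainer = build_trainer(cfg)                                # Dassl-style factory, assumed
    trainer.load_model(args.model_dir, epoch=args.load_epoch)
    trainer.test()                                              # evaluate with the restored prompts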