in trainers/catex.py [0:0]
def load_model(self, directory, epoch=None):
    """Load saved weights into every model returned by ``get_model_names()``.

    For each model name the checkpoint is read from
    ``<directory>/<name>/<model_file>``, where ``model_file`` is
    ``"model-best.pth.tar"`` by default or ``"model.pth.tar-<epoch>"`` when
    ``epoch`` is given. The ``"token_prefix"`` and ``"token_suffix"``
    entries are dropped from the checkpoint before loading, and the
    remaining entries are loaded with ``strict=False``.

    When ``cfg.TRAINER.CATEX.CTX_INIT`` is set, the model's
    ``text_feature_ensemble`` is recomputed from the freshly loaded weights.

    Args:
        directory: Root directory holding one sub-directory per model name.
            If falsy, nothing is loaded and the method returns immediately.
        epoch: Optional epoch selecting which checkpoint file to load.

    Raises:
        FileNotFoundError: If an expected checkpoint file does not exist.
        AssertionError: If a checkpoint contains keys the target model does
            not have, or if ``CTX_INIT`` is set to anything other than
            ``'ensemble_learned'``.
    """
    if not directory:
        print("Note that load_model() is skipped as no pretrained model is given")
        return

    # By default, the best model is loaded
    model_file = "model-best.pth.tar"
    if epoch is not None:
        model_file = "model.pth.tar-" + str(epoch)

    for name in self.get_model_names():
        model_path = osp.join(directory, name, model_file)
        if not osp.exists(model_path):
            raise FileNotFoundError('Model not found at "{}"'.format(model_path))

        checkpoint = load_checkpoint(model_path)
        state_dict = checkpoint["state_dict"]
        # Kept separate from the ``epoch`` argument so the caller's request
        # is not overwritten by whatever the checkpoint recorded.
        ckpt_epoch = checkpoint["epoch"]

        # Ignore fixed token vectors
        state_dict.pop("token_prefix", None)
        state_dict.pop("token_suffix", None)

        # ``strict=False`` below silently skips unknown keys, so this is the
        # only guard against loading a checkpoint that does not match the
        # model. Raised explicitly (rather than via ``assert``) so the check
        # still runs under ``python -O``.
        model_dict = self._models[name].state_dict()
        unexpected = [k for k in state_dict if k not in model_dict]
        if unexpected:
            raise AssertionError(
                'Checkpoint "{}" has keys not present in model "{}": {}'.format(
                    model_path, name, unexpected
                )
            )

        print(
            "Loading weights to {} {} " 'from "{}" (epoch = {})'.format(
                name, list(state_dict.keys()), model_path, ckpt_epoch
            )
        )
        # set strict=False
        self._models[name].load_state_dict(state_dict, strict=False)

    # Rebuild the stored text-feature ensemble once all weights are loaded.
    if self.cfg.TRAINER.CATEX.CTX_INIT:
        assert self.cfg.TRAINER.CATEX.CTX_INIT == 'ensemble_learned'
        text_feature = self.model.get_text_features()
        self.model.text_feature_ensemble = self.model.prompt_ensemble(text_feature)