in empchat/models.py [0:0]
def load(filename, new_opt):
    """Load a saved checkpoint and rebuild the network it describes.

    The checkpoint is read as a mapping with the keys ``"word_dict"``,
    ``"state_dict"`` and ``"opt"``. Any option present on ``new_opt`` but
    missing from the saved options is copied onto the saved options (with a
    warning) before the network is created; options that were saved with the
    checkpoint are never overridden.

    Args:
        filename: Checkpoint location handed straight to ``torch.load``.
        new_opt: Options object for the current run. It must work with
            ``vars()`` and expose a ``model`` attribute.

    Returns:
        A ``(net, word_dict)`` tuple: the network built by ``create`` with the
        saved weights loaded into it, and the saved word dictionary.

    Raises:
        AssertionError: If ``new_opt.model`` is ``"bert"`` but the saved word
            dictionary has no ``"bert_tokenizer"`` entry.
    """
    logging.info(f"Loading model {filename}")
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    # The identity map_location keeps every tensor on the CPU regardless of
    # the device it was saved from.
    saved_params = torch.load(filename, map_location=lambda storage, loc: storage)
    word_dict = saved_params["word_dict"]
    state_dict = saved_params["state_dict"]
    saved_opt = saved_params["opt"]

    # Back-fill options that did not exist when the checkpoint was written.
    for k, v in vars(new_opt).items():
        if not hasattr(saved_opt, k):
            logging.warning(f"Setting {k} to {v}")
            setattr(saved_opt, k, v)

    # Guarantee the merged options carry a ``fasttext`` attribute. The loop
    # above only supplies it when ``new_opt`` defines one, so fall back to
    # None when neither side has it instead of reading a missing attribute.
    # NOTE(review): assumes None means "fastText disabled" for ``create`` --
    # confirm against its implementation.
    if not hasattr(saved_opt, "fasttext"):
        saved_opt.fasttext = getattr(new_opt, "fasttext", None)

    # Raised explicitly rather than via ``assert`` so the check still runs
    # under ``python -O``; the exception type is kept for existing callers.
    if new_opt.model == "bert" and "bert_tokenizer" not in word_dict:
        raise AssertionError(
            "BERT model requested but the checkpoint's word_dict has no "
            "'bert_tokenizer' entry"
        )

    net = create(saved_opt, word_dict["words"])
    # strict=False tolerates missing or unexpected keys in the saved weights.
    net.load_state_dict(state_dict, strict=False)
    return net, word_dict