def load()

in empchat/models.py


import logging

import torch


def load(filename, new_opt):
    logging.info(f"Loading model {filename}")
    # Map all tensors onto CPU, wherever the checkpoint was saved from
    saved_params = torch.load(filename, map_location=lambda storage, loc: storage)
    word_dict = saved_params["word_dict"]
    state_dict = saved_params["state_dict"]
    saved_opt = saved_params["opt"]
    # Backfill any options introduced after this checkpoint was saved
    for k, v in vars(new_opt).items():
        if not hasattr(saved_opt, k):
            logging.warning(f"Setting {k} to {v}")
            setattr(saved_opt, k, v)
    # Check saved_opt here: if new_opt lacked "fasttext", reading
    # new_opt.fasttext on the next line would raise AttributeError
    if not hasattr(saved_opt, "fasttext"):
        setattr(saved_opt, "fasttext", new_opt.fasttext)
    if new_opt.model == "bert":
        assert "bert_tokenizer" in word_dict
    # create() (defined elsewhere in empchat/models.py) builds the network
    net = create(saved_opt, word_dict["words"])
    # strict=False tolerates missing or unexpected keys in the state dict
    net.load_state_dict(state_dict, strict=False)
    return net, word_dict
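
As the code above shows, the checkpoint is expected to be a dict with "word_dict", "state_dict", and "opt" keys, so load() pairs with whatever save routine wrote those fields. A minimal usage sketch, assuming an argparse-style options namespace; the checkpoint path and option values here are hypothetical:

from argparse import Namespace

# Hypothetical options; real runs would populate these via argparse
new_opt = Namespace(model="transformer", fasttext=None)

# "model.pt" is a placeholder path to a checkpoint saved by this codebase
net, word_dict = load("model.pt", new_opt)
net.eval()  # switch the returned module to inference mode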