in empchat/models.py [0:0]
def load_embeddings(opt, dictionary, model):
    """Initialize ``model.embeddings.weight`` from a text embeddings file.

    Each line of the file must be a token followed by ``opt.embeddings_size``
    space-separated floats. For every token found in ``dictionary``, the
    corresponding row of the embedding table is overwritten with the file's
    vector; rows of tokens absent from the file keep their current values.
    The padding row (``dictionary[PAD_TOKEN]``) is zeroed afterwards.

    Args:
        opt: options object; reads ``opt.embeddings`` (path of the file),
            ``opt.embeddings_size`` (expected vector dimension) and
            ``opt.normalize_emb`` (if truthy, L2-normalize each loaded vector).
        dictionary: mapping from token to row index in the embedding table;
            must contain ``PAD_TOKEN``.
        model: module exposing ``embeddings.weight``.

    Raises:
        ValueError: if a line does not contain exactly
            ``opt.embeddings_size + 1`` space-separated fields.
        KeyError: if ``PAD_TOKEN`` is not in ``dictionary``.
    """
    import itertools

    path = opt.embeddings
    logging.info(f"Loading embeddings file from {path}")
    emb_table = model.embeddings.weight
    expected_fields = opt.embeddings_size + 1
    n_added = 0
    # Tokens not loaded yet; membership here also filters out duplicates.
    missing_dict = set(dictionary.keys())
    # In-place writes to a parameter that requires grad must bypass autograd,
    # otherwise PyTorch raises on e.g. ``emb_table[i].zero_()``. no_grad does
    # not touch ``emb_table.requires_grad``, so no save/restore is needed.
    with torch.no_grad():
        with open(path, encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                parsed = line.rstrip().split(" ")
                if len(parsed) != expected_fields:
                    raise ValueError(
                        f"{path}:{line_no}: expected {expected_fields} fields "
                        f"(token + {opt.embeddings_size} values), "
                        f"got {len(parsed)}"
                    )
                w = parsed[0]
                # Skip tokens not in the dictionary, and repeated entries of a
                # token already loaded (first occurrence wins).
                if w not in missing_dict:
                    continue
                vec = torch.Tensor([float(x) for x in parsed[1:]])
                if opt.normalize_emb:
                    norm = vec.norm(2)
                    # Leave all-zero vectors as-is instead of producing NaNs.
                    if norm > 0:
                        vec = vec / norm
                emb_table[dictionary[w]] = vec
                missing_dict.remove(w)
                n_added += 1
        # Zero the padding row last so a pad entry in the file cannot
        # overwrite it.
        emb_table[dictionary[PAD_TOKEN]].zero_()
    sample = ", ".join(itertools.islice(missing_dict, 8))
    logging.info(
        f"Loaded {n_added} vectors from embeddings file; {len(missing_dict)} are "
        f"missing, among which: {sample}"
    )