in drqa/reader/model.py [0:0]
def tune_embeddings(self, words):
    """Mark the embeddings of the given words as trainable.

    Only meaningful when a partial set of embeddings is tuned
    (tune_partial = N). The N requested words are swapped to the front of
    the dictionary (starting at START), and the untouched tail of the
    embedding matrix is snapshotted into a fixed buffer so it can be
    restored after updates.

    Args:
        words: iterable of tokens contained in dictionary.
    """
    # Keep only words that actually exist in the dictionary.
    words = {w for w in words if w in self.word_dict}
    if not words:
        logger.warning('Tried to tune embeddings, but no words given!')
        return
    if len(words) == len(self.word_dict):
        # Every embedding is tuned; nothing needs fixing or shuffling.
        logger.warning('Tuning ALL embeddings in dictionary')
        return
    # Move each tuned word (and its vector) to the front of the table,
    # swapping with whatever currently occupies that slot.
    emb = self.network.embedding.weight.data
    for front_idx, word in enumerate(words, self.word_dict.START):
        displaced = self.word_dict[front_idx]      # word being bumped out
        displaced_vec = emb[front_idx].clone()
        source_idx = self.word_dict[word]          # word's current slot
        # Exchange embedding rows in place.
        emb[front_idx].copy_(emb[source_idx])
        emb[source_idx].copy_(displaced_vec)
        # Keep both directions (word -> idx, idx -> word) consistent.
        self.word_dict[word] = front_idx
        self.word_dict[front_idx] = word
        self.word_dict[displaced] = source_idx
        self.word_dict[source_idx] = displaced
    # Everything past the last tuned slot stays fixed; snapshot it.
    self.network.register_buffer(
        'fixed_embedding', emb[front_idx + 1:].clone()
    )