in data.py [0:0]
def __init__(self, path, sort_dict=False):
    """Build the vocabulary from the text file at ``path``.

    Every line is split on whitespace and an ``'<eos>'`` token is appended
    to the resulting words before they are added to the vocabulary.

    Args:
        path: Path of the text file to read (opened as UTF-8).
        sort_dict: If false (the default), each new word receives the next
            free index in order of first appearance and ``word2count``
            stays empty. If true, occurrences are counted in
            ``word2count`` and indices are assigned by descending count.

    Raises:
        AssertionError: If ``path`` does not exist (only while assertions
            are enabled; otherwise ``open`` reports the error).
    """
    self.word2idx = {}
    self.word2count = {}
    self.idx2word = []
    assert os.path.exists(path), path
    # Add words to the dictionary
    with open(path, 'r', encoding="utf8") as f:
        for line in f:
            words = line.split() + ['<eos>']
            for word in words:
                if sort_dict:
                    # Only count here; indices are assigned after the
                    # whole file has been read.
                    self.word2count[word] = self.word2count.get(word, 0) + 1
                elif word not in self.word2idx:
                    self.word2idx[word] = len(self.idx2word)
                    self.idx2word.append(word)
    if sort_dict:
        # Sort dictionary by count and build indices accordingly.
        # The ascending stable sort is reversed as a whole (rather than
        # using reverse=True) so that words with equal counts keep their
        # established order: the one first seen later gets the lower index.
        ranked = sorted(self.word2count.items(), key=lambda kv: kv[1])[::-1]
        for index, (word, _count) in enumerate(ranked):
            self.word2idx[word] = index
            self.idx2word.append(word)