in utils/corpus_utils.py [0:0]
def load_embedding(embedding_path, embedding_dim, format, file_type, with_head=False, word_set=None):
    """Load pre-trained word vectors from ``embedding_path``.

    Args:
        embedding_path: path of the embedding file.
        embedding_dim: expected vector size; any mismatch raises ``ValueError``.
        format: 'glove', 'word2vec' or 'fasttext'.
        file_type: 'text' or 'binary'. Only consulted for 'word2vec' and
            'fasttext'; any value other than 'text' is treated as binary.
            GloVe files are always read as UTF-8 text.
        with_head: GloVe only -- unconditionally skip the first line of the
            file (e.g. a "<vocab_size> <dim>" header). Independently of this
            flag, a first remaining line made of exactly two tokens is also
            skipped as a header.
        word_set: optional container of words to keep (anything supporting
            ``in``). ``None`` keeps every word.

    Returns:
        * word2vec / fasttext with ``word_set=None``: the loaded gensim object
          itself (``KeyedVectors``, or the fastText model for binary fastText).
        * word2vec / fasttext with a ``word_set``: dict word -> vector for the
          vocabulary words contained in ``word_set``.
        * glove: dict word -> list of floats.

    Raises:
        ValueError: if ``format`` is unsupported, or a vector's size differs
            from ``embedding_dim``.
    """
    if format in ('word2vec', 'fasttext'):
        # Only the binary fastText path returns a full model instead of
        # KeyedVectors; the load and the `.wv` lookup below must agree on this.
        is_fasttext_model = format == 'fasttext' and file_type != 'text'
        if file_type == 'text':
            vector_total = KeyedVectors.load_word2vec_format(
                embedding_path, binary=False, unicode_errors='ignore')
        elif format == 'word2vec':
            vector_total = KeyedVectors.load_word2vec_format(
                embedding_path, binary=True, unicode_errors='ignore')
        else:
            vector_total = FastText.load_fasttext_format(embedding_path, encoding='utf8')
        if vector_total.vector_size != embedding_dim:
            raise ValueError(
                'Embedding dimension mismatch in %s: file has %d, expected %d.'
                % (embedding_path, vector_total.vector_size, embedding_dim))
        if word_set is None:
            return vector_total
        # The fastText model keeps its vocabulary and vectors on `.wv`.
        keyed_vectors = vector_total.wv if is_fasttext_model else vector_total
        # index2word is the vocabulary list, in the model's stored order.
        return {word: keyed_vectors[word]
                for word in keyed_vectors.index2word
                if word in word_set}

    if format == 'glove':
        embedding_dict = {}
        # Builtin open with newline='\n' splits records only on '\n'.
        # codecs.open's reader also breaks lines at U+2028, U+0085, \x1c-\x1e,
        # ..., which occur inside tokens of large crawled vocabularies and
        # would cut one record into two.
        with open(embedding_path, 'r', encoding='utf-8', newline='\n') as fin:
            if with_head:
                fin.readline()
            for idx, line in enumerate(fin):
                line = line.rstrip()
                # A two-token first line is a "<vocab_size> <dim>" header.
                if idx == 0 and len(line.split()) == 2:
                    continue
                if not line:
                    continue
                # The word is everything before the first plain space, so it
                # may itself contain other (non-ASCII) whitespace characters.
                word, _, vec = line.partition(' ')
                if word_set is not None and word not in word_set:
                    continue
                # split() without arguments tolerates repeated separators.
                vector = [float(num) for num in vec.split()]
                if len(vector) != embedding_dim:
                    raise ValueError(
                        'Embedding dimension mismatch in %s for word %r: got %d values, expected %d.'
                        % (embedding_path, word, len(vector), embedding_dim))
                embedding_dict[word] = vector
        return embedding_dict

    raise ValueError(
        'Supported formats are glove, word2vec or fasttext; %s is not supported.' % format)