in lama/vocab_intersection.py [0:0]
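# Module-level imports assumed by this excerpt (the full lama/vocab_intersection.py
# file provides them): argparse, spacy and tqdm are shown explicitly below, while
# build_model_by_name and `base` (which defines SPECIAL_SYMBOLS) are assumed to come
# from LAMA's model-connector modules.
import argparse

import spacy
from tqdm import tqdm
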
def __vocab_intersection(models, filename):

    vocabularies = []

    # instantiate each model and collect its vocabulary
    for arg_dict in models:
        args = argparse.Namespace(**arg_dict)
        print(args)
        model = build_model_by_name(args.lm, args)
        vocabularies.append(model.vocab)
        print(type(model.vocab))
    if len(vocabularies) > 0:
        common_vocab = set(vocabularies[0])
        for vocab in vocabularies:
            common_vocab = common_vocab.intersection(set(vocab))

        # no special symbols in common_vocab
        for symbol in base.SPECIAL_SYMBOLS:
            if symbol in common_vocab:
                common_vocab.remove(symbol)
        # remove stop words
        from spacy.lang.en.stop_words import STOP_WORDS
        for stop_word in STOP_WORDS:
            if stop_word in common_vocab:
                print(stop_word)
                common_vocab.remove(stop_word)

        common_vocab = list(common_vocab)
        # remove punctuation and symbols
        nlp = spacy.load('en')  # spaCy 2.x shortcut name; newer spaCy versions use e.g. 'en_core_web_sm'
        manual_punctuation = ['(', ')', '.', ',']
        new_common_vocab = []
        for i in tqdm(range(len(common_vocab))):
            word = common_vocab[i]
            doc = nlp(word)
            token = doc[0]
            if len(doc) != 1:
                # entries that spaCy splits into multiple tokens are reported and dropped
                print(word)
                for idx, tok in enumerate(doc):
                    print("{} - {}".format(idx, tok))
            elif word in manual_punctuation:
                pass
            elif token.pos_ == "PUNCT":
                print("PUNCT: {}".format(word))
            elif token.pos_ == "SYM":
                print("SYM: {}".format(word))
            else:
                new_common_vocab.append(word)
                # print("{} - {}".format(word, token.pos_))
        common_vocab = new_common_vocab
        # store common_vocab on file
        with open(filename, 'w') as f:
            for item in sorted(common_vocab):
                f.write("{}\n".format(item))
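
# Side note, a minimal sketch (not part of the original file): the seeded
# intersection loop inside __vocab_intersection is equivalent to a single
# set.intersection over all collected vocabularies. The helper name below is
# hypothetical and shown for illustration only.
def _intersect_vocabularies(vocabularies):
    # tokens shared by every vocabulary; an empty input yields an empty set
    return set.intersection(*(set(v) for v in vocabularies)) if vocabularies else set()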