in projects/krisp/graphnetwork_module.py [0:0]
import numpy as np
from gensim.models import KeyedVectors
from gensim.scripts.glove2word2vec import glove2word2vec
from tqdm import tqdm


def prepare_embeddings(node_names, embedding_file, add_split):
"""
This function is used to prepare embeddings for the graph
:param embedding_file: location of the raw embedding file
:return:
"""
print("\n\nCreating node embeddings...")
embedding_model = ""
if "glove" in embedding_file:
embedding_model = "glove"
elif "GoogleNews" in embedding_file:
embedding_model = "word2vec"
elif "subword" in embedding_file:
embedding_model = "fasttext"
elif "numberbatch" in embedding_file:
embedding_model = "numberbatch"
    def transform(compound_word):
        words = compound_word.split(" ")
        return [
            compound_word,
            "_".join(w.lower() for w in words),
            "_".join(w.capitalize() for w in words),
            "-".join(words),
            "-".join(w.capitalize() for w in words),
        ]
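    # Map from node name to its embedding vector; nodes that cannot be
    # matched in the pretrained vocabulary are collected separately below.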
node2vec = {}
model = None
# glove has a slightly different format
if embedding_model == "glove":
tmp_file = ".".join(embedding_file.split(".")[:-1]) + "_tmp.txt"
glove2word2vec(embedding_file, tmp_file)
embedding_file = tmp_file
# Important: only native word2vec file needs binary flag to be true
print(f"Loading pretrained embeddings from {embedding_file} ...")
model = KeyedVectors.load_word2vec_format(
embedding_file, binary=(embedding_model == "word2vec")
)
# retrieve embeddings for graph nodes
no_match_nodes = []
match_positions = []
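    # For each node, try the candidate forms in order and keep the first hit;
    # the 1-based index of the matching form is recorded for statistics.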
for node_name in tqdm(node_names, desc="Prepare node embeddings"):
        try_words = transform(node_name)
# Try to find w2v
found_mapping = False
for i, try_word in enumerate(try_words):
try:
node2vec[node_name] = model.get_vector(try_word)
match_positions.append(i + 1)
found_mapping = True
except KeyError:
pass
if found_mapping:
break
        # Try multi-word names: average the w2v of the individual words.
        # add_split controls whether names are split on spaces or underscores.
        delimiter = " " if add_split else "_"
        if not found_mapping and len(node_name.split(delimiter)) > 1:
            sub_word_vecs = []
            mp = []
            for subword in node_name.split(delimiter):
                # Get w2v for the individual word, trying the same forms
                found_submap = False
                for i, try_word in enumerate(transform(subword)):
                    try:
                        sub_word_vecs.append(model.get_vector(try_word))
                        mp.append(i + 1)
                        found_submap = True
                    except KeyError:
                        pass
                    if found_submap:
                        break
            # If all subwords matched, average them into the node embedding
            if len(sub_word_vecs) == len(node_name.split(delimiter)):
                node2vec[node_name] = np.mean(sub_word_vecs, 0)
                match_positions.append(
                    np.mean(mp)
                )  # match_positions is otherwise only used for counts
                found_mapping = True
# All else fails, it's a no match
if not found_mapping:
no_match_nodes.append([node_name, try_words])
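
    # A minimal usage sketch (the node names and embedding path below are
    # hypothetical examples, not files shipped with the repo):
    #
    #     node_names = ["cat", "new york", "traffic_light"]
    #     prepare_embeddings(
    #         node_names, "embeddings/glove.6B.300d.txt", add_split=True
    #     )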