# def prepare_embeddings()
#
# in projects/krisp/graphnetwork_module.py [0:0]

def prepare_embeddings(node_names, embedding_file: str, add_split: bool):
    """
    Look up a pretrained word embedding for every graph node name.

    The embedding family (glove / word2vec / fasttext / numberbatch) is
    inferred from substrings of the file name.  For each node name a list of
    casing/joining variants is tried against the model vocabulary; when no
    variant matches, the name is split into sub-words (on " " when
    ``add_split`` is true, otherwise on "_") and the mean of the sub-word
    vectors is used instead.  Names that still have no vector are collected
    in ``no_match_nodes``.

    :param node_names: iterable of node-name strings to embed
    :param embedding_file: location of the raw embedding file
    :param add_split: if true, the multi-word fallback splits names on
        spaces; otherwise it splits on underscores
    :return: (not visible in this excerpt — the function continues below)
    """
    print("\n\nCreating node embeddings...")

    # Infer the embedding format from the file name alone.
    embedding_model = ""
    if "glove" in embedding_file:
        embedding_model = "glove"
    elif "GoogleNews" in embedding_file:
        embedding_model = "word2vec"
    elif "subword" in embedding_file:
        embedding_model = "fasttext"
    elif "numberbatch" in embedding_file:
        embedding_model = "numberbatch"

    def transform(compound_word):
        # Candidate vocabulary keys for a (possibly multi-word) name, tried
        # in order: verbatim, lower_snake_case, Capitalized_Snake_Case,
        # hyphen-joined.
        # NOTE(review): the last two entries are byte-identical duplicates —
        # one was probably meant to be a different variant (e.g. lowercase
        # hyphenated).  Removing the duplicate would shift the 1-based
        # indices recorded in match_positions, so confirm intent first.
        return [
            compound_word,
            "_".join([w.lower() for w in compound_word.split(" ")]),
            "_".join([w.capitalize() for w in compound_word.split(" ")]),
            "-".join([w for w in compound_word.split(" ")]),
            "-".join([w for w in compound_word.split(" ")]),
        ]

    node2vec = {}  # node name -> embedding vector
    model = None

    # glove has a slightly different format: convert it to word2vec text
    # format in a "<name>_tmp.txt" file and load that instead.
    if embedding_model == "glove":
        tmp_file = ".".join(embedding_file.split(".")[:-1]) + "_tmp.txt"
        glove2word2vec(embedding_file, tmp_file)
        embedding_file = tmp_file

    # Important: only native word2vec file needs binary flag to be true
    print(f"Loading pretrained embeddings from {embedding_file} ...")
    model = KeyedVectors.load_word2vec_format(
        embedding_file, binary=(embedding_model == "word2vec")
    )

    # retrieve embeddings for graph nodes
    no_match_nodes = []  # [node_name, tried variants] for unmatched names
    match_positions = []  # 1-based index of the variant that matched

    for node_name in tqdm(node_names, desc="Prepare node embeddings"):
        try_words = []
        try_words.extend(transform(node_name))

        # Try to find w2v: the first variant present in the vocabulary wins
        # (get_vector raises KeyError for out-of-vocabulary words).
        found_mapping = False
        for i, try_word in enumerate(try_words):
            try:
                node2vec[node_name] = model.get_vector(try_word)
                match_positions.append(i + 1)
                found_mapping = True
            except KeyError:
                # Variant not in vocabulary; fall through to the next one.
                pass
            if found_mapping:
                break

        # Try multi-words (average w2v).
        # NOTE(review): the two branches below are copy-paste identical
        # except for the split delimiter (" " vs "_"); a shared helper
        # taking the delimiter would remove the duplication.
        if add_split:
            if not found_mapping and len(node_name.split(" ")) > 1:
                sub_word_vecs = []
                for subword in node_name.split(" "):
                    # Get w2v for the individual words
                    try_words = []
                    try_words.extend(transform(subword))
                    mp = []  # match position(s) for this sub-word only
                    found_submap = False
                    for i, try_word in enumerate(try_words):
                        try:
                            sub_word_vecs.append(model.get_vector(try_word))
                            mp.append(i + 1)
                            found_submap = True
                        except KeyError:
                            pass
                        if found_submap:
                            break

                # If all subswords successful, add it to node2vec and match_positions
                if len(sub_word_vecs) == len(node_name.split(" ")):
                    node2vec[node_name] = np.mean(sub_word_vecs, 0)
                    # NOTE(review): mp is reset per sub-word above, so this
                    # averages only the LAST sub-word's match position.
                    match_positions.append(
                        np.mean(mp)
                    )  # I'm sort of ignoring match_positions except for counts
                    found_mapping = True
        else:
            if not found_mapping and len(node_name.split("_")) > 1:
                sub_word_vecs = []
                for subword in node_name.split("_"):
                    # Get w2v for the individual words
                    try_words = []
                    try_words.extend(transform(subword))
                    mp = []  # match position(s) for this sub-word only
                    found_submap = False
                    for i, try_word in enumerate(try_words):
                        try:
                            sub_word_vecs.append(model.get_vector(try_word))
                            mp.append(i + 1)
                            found_submap = True
                        except KeyError:
                            pass
                        if found_submap:
                            break

                # If all subswords successful, add it to node2vec and match_positions
                if len(sub_word_vecs) == len(node_name.split("_")):
                    node2vec[node_name] = np.mean(sub_word_vecs, 0)
                    # NOTE(review): mp is reset per sub-word above, so this
                    # averages only the LAST sub-word's match position.
                    match_positions.append(
                        np.mean(mp)
                    )  # I'm sort of ignoring match_positions except for counts
                    found_mapping = True

        # All else fails, it's a no match: record the name and every variant
        # that was tried (try_words holds the last sub-word's variants when
        # the multi-word fallback ran).
        if not found_mapping:
            no_match_nodes.append([node_name, try_words])