def merge_nodes()

in data_measurements/embeddings/embeddings.py [0:0]


def _get_root_node(nodes, nid):
    """Return the top-level ancestor of node ``nid`` (the node whose ``parent_id`` is -1)."""
    node = nodes[nid]
    while node["parent_id"] != -1:
        node = nodes[node["parent_id"]]
    return node


def merge_nodes(nodes, current_thres, previous_thres, all_merges, all_merge_scores):
    """
    Merge all nodes if the max dot product between any of their descendants
    is greater than current_thres.

    Only proposed merges whose score lies in (current_thres, previous_thres]
    are processed. For each such pair (a, b) the current roots containing a
    and b are looked up. If the roots differ, they are joined in one of two
    ways:

    * If at least one root is an internal node and one of them was created at
      ``current_thres`` (i.e. during this call), the other root is folded into
      it. A folded internal root that was also created at ``current_thres``
      represents the same level of the hierarchy, so it is dissolved and its
      children are moved over. A leaf, or a cluster created at an earlier
      (higher) threshold, is attached whole as a child, so that clusters found
      at higher thresholds are preserved.
    * Otherwise a new internal node with ``merge_threshold == current_thres``
      is appended, with the two roots as its children.

    Args:
        nodes ([dict]): list of dicts representing the current set of nodes;
            the list and its dicts are modified in place
        current_thres (float): merge all nodes closer than current_thres
        previous_thres (float): nodes closer than previous_thres are already merged
        all_merges (torch.LongTensor): proposed merges ([i, j] with i>j) - dimension: Mx2
        all_merge_scores (torch.Tensor): merge scores - dimension M
    Returns:
        [dict]: extended list with the newly created internal nodes (the same
        list object that was passed in as ``nodes``)
    """
    in_range = (all_merge_scores <= previous_thres) & (
        all_merge_scores > current_thres
    )
    if not in_range.any():
        return nodes

    for a, b in all_merges[in_range].tolist():
        root_a = _get_root_node(nodes, a)
        root_b = _get_root_node(nodes, b)
        if root_a["nid"] == root_b["nid"]:
            # a and b already belong to the same cluster
            continue

        # merge into an existing node if threshold allows
        can_fold = (root_a["depth"] + root_b["depth"]) > 0 and min(
            root_a["merge_threshold"], root_b["merge_threshold"]
        ) == current_thres
        if can_fold:
            # Nodes are appended in creation order, so the root with the larger
            # nid is the more recently created one; fold the other into it.
            if root_a["nid"] < root_b["nid"]:
                merge_from, merge_to = root_a, root_b
            else:
                merge_from, merge_to = root_b, root_a
            merge_to["weight"] += merge_from["weight"]
            hoist_children = (
                merge_from["depth"] > 0
                and merge_from["merge_threshold"] == current_thres
            )
            if hoist_children:
                # Same-level cluster: move its children directly under merge_to.
                merge_to["depth"] = max(merge_to["depth"], merge_from["depth"])
                merge_to["children_ids"] += merge_from["children_ids"]
                for cid in merge_from["children_ids"]:
                    nodes[cid]["parent_id"] = merge_to["nid"]
            else:
                # Leaf or cluster from a higher threshold: keep it intact as a child.
                merge_to["depth"] = max(merge_to["depth"], merge_from["depth"] + 1)
                merge_to["children_ids"].append(merge_from["nid"])
            # Either way merge_from is no longer a root.
            merge_from["parent_id"] = merge_to["nid"]
        else:
            # Create a new internal node at current_thres joining both roots.
            new_nid = len(nodes)
            nodes.append(
                {
                    "nid": new_nid,
                    "parent_id": -1,
                    "depth": max(root_a["depth"], root_b["depth"]) + 1,
                    "weight": root_a["weight"] + root_b["weight"],
                    "children": [],
                    "children_ids": [root_a["nid"], root_b["nid"]],
                    "example_ids": [],
                    "merge_threshold": current_thres,
                }
            )
            root_a["parent_id"] = new_nid
            root_b["parent_id"] = new_nid
    return nodes