in data_measurements/embeddings/embeddings.py [0:0]
def merge_nodes(nodes, current_thres, previous_thres, all_merges, all_merge_scores):
    """
    Apply every proposed merge whose score falls in the band
    (current_thres, previous_thres], growing the node hierarchy in place.

    For each selected pair [a, b] the current root of each node is found by
    following "parent_id" links until -1. Pairs that already share a root are
    skipped. Otherwise one of two things happens:
      - if at least one root is internal (depth > 0) and the smaller of the two
        roots' "merge_threshold" values equals current_thres, the root with the
        smaller "nid" is folded into the root with the larger "nid";
      - in every other case a new parent node is appended, with the two roots
        as its children and "merge_threshold" set to current_thres.

    Args:
        nodes ([dict]): list of dicts representing the current set of nodes;
            it is indexed by node id and mutated in place
        current_thres (float): merge all nodes closer than current_thres
        previous_thres (float): nodes closer than previous_thres are already merged
        all_merges (torch.LongTensor): proposed merges ([i, j] with i>j) - dimension: Mx2
        all_merge_scores (torch.Tensor): merge scores - dimension M
    Returns:
        [dict]: the same list object, extended with the newly created internal nodes
    """

    def find_root(idx):
        # Climb the parent links until reaching a node that has no parent.
        root = nodes[idx]
        while root["parent_id"] != -1:
            root = nodes[root["parent_id"]]
        return root

    # Boolean mask over the proposals: current_thres < score <= previous_thres.
    in_band = (all_merge_scores > current_thres) & (
        all_merge_scores <= previous_thres
    )
    if in_band.sum().item() <= 0:
        # No proposal falls inside this threshold band.
        return nodes

    for left, right in all_merges[in_band].tolist():
        root_a = find_root(left)
        root_b = find_root(right)
        if root_a["nid"] == root_b["nid"]:
            # Both endpoints already belong to the same tree.
            continue

        # The depth test comes first so that "merge_threshold" is only read
        # when at least one of the roots is an internal node.
        fold_into_existing = (root_a["depth"] + root_b["depth"]) > 0 and (
            min(root_a["merge_threshold"], root_b["merge_threshold"])
            == current_thres
        )

        if not fold_into_existing:
            # Create a fresh parent on top of the two roots.
            parent_nid = len(nodes)
            parent = {
                "nid": parent_nid,
                "parent_id": -1,
                "depth": max(root_a["depth"], root_b["depth"]) + 1,
                "weight": root_a["weight"] + root_b["weight"],
                "children": [],
                "children_ids": [root_a["nid"], root_b["nid"]],
                "example_ids": [],
                "merge_threshold": current_thres,
            }
            root_a["parent_id"] = parent_nid
            root_b["parent_id"] = parent_nid
            nodes.append(parent)
            continue

        # The ids differ here, so exactly one ordering applies: the root with
        # the smaller nid is absorbed by the root with the larger nid.
        if root_a["nid"] < root_b["nid"]:
            absorbed, keeper = root_a, root_b
        else:
            absorbed, keeper = root_b, root_a

        keeper["depth"] = max(keeper["depth"], absorbed["depth"])
        keeper["weight"] += absorbed["weight"]
        if absorbed["depth"] > 0:
            # An internal node hands over its children.
            keeper["children_ids"].extend(absorbed["children_ids"])
        else:
            # A leaf becomes a direct child itself.
            keeper["children_ids"].append(absorbed["nid"])
        for child_id in absorbed["children_ids"]:
            nodes[child_id]["parent_id"] = keeper["nid"]
        absorbed["parent_id"] = keeper["nid"]

    return nodes