in data_measurements/embeddings/embeddings.py [0:0]
def _make_node_label(nid, node, text_dset, text_field_name, max_examples=5, max_chars=64):
    """Build the HTML hover label for a single tree node.

    The label starts with the node index and ``node["weight"]``. It is followed
    by up to ``max_examples`` distinct example texts, taken in the order of
    ``node["sorted_examples_centroid"]``. Each text is prefixed with its score
    and truncated to ``max_chars`` characters. "..." is appended only when
    truncation actually happened.

    Args:
        nid: Index of the node in the node list (shown in the label).
        node: Node dict; reads "sorted_examples_centroid" (iterable of
            ``(example_id, score)`` pairs) and "weight" (int).
        text_dset: Indexable dataset; ``text_dset[sid][text_field_name]``
            must return the example text.
        text_field_name: Name of the text field in ``text_dset`` rows.
        max_examples: Maximum number of distinct example texts to show.
        max_chars: Maximum number of characters shown per example text.

    Returns:
        The label string, with lines separated by ``<br>``.
    """
    examples = {}
    for sid, score in node["sorted_examples_centroid"]:
        if len(examples) >= max_examples:
            break
        # Keyed by text so duplicate texts appear once; a repeated text keeps
        # its first position but takes the later score.
        examples[text_dset[sid][text_field_name]] = score
    lines = [
        f" {score:.2f} > {txt[:max_chars]}" + ("..." if len(txt) > max_chars else "")
        for txt, score in examples.items()
    ]
    return f"{nid:2d} - {node['weight']:5d} items <br>" + "<br>".join(lines)


def _layout_tree(root):
    """Assign "X"/"Y" plot coordinates to every node reachable from ``root``.

    The root is placed at (0, 0) and each level sits one unit below its
    parent. Children are laid out left to right starting at the parent's X.
    Each child advances the offset by its "weight" plus an equal share of the
    parent's items that are not accounted for by its children's weights.
    The traversal is iterative so deep trees do not hit Python's recursion
    limit.

    Args:
        root: Root node dict; every visited node must have "children"
            (list of node dicts), and every node with children must have
            "example_ids" (sized) and children with "weight".
    """
    root["X"] = 0
    root["Y"] = 0
    stack = [root]
    while stack:
        node = stack.pop()
        children = node["children"]
        if not children:
            continue
        # Items held directly by this node (not covered by any child's weight)
        # are spread evenly as extra horizontal spacing between its children.
        extra = (len(node["example_ids"]) - sum(child["weight"] for child in children)) / len(children)
        offset = 0
        for child in children:
            child["X"] = node["X"] + offset
            child["Y"] = node["Y"] - 1
            offset += child["weight"] + extra
            stack.append(child)


def make_tree_plot(node_list, nid_map, text_dset, text_field_name, max_examples=5, max_chars=64):
    """
    Makes a graphical representation of the tree encoded in ``node_list``.

    The hover label for each node shows the node index, its weight, and up to
    ``max_examples`` example texts taken from the front of the node's
    "sorted_examples_centroid" list (presumably the examples closest to the
    node centroid).

    Nodes are mutated in place: a "label" key is added when missing, and
    "X"/"Y" coordinates are assigned to every node reachable from
    ``node_list[0]``, which is treated as the root.

    Args:
        node_list: List of node dicts. Each node needs "children" and, to get
            a generated label, "sorted_examples_centroid" and "weight".
            Nodes with children also need "example_ids". Every node in
            the list must be reachable from ``node_list[0]`` so that it
            receives coordinates.
        nid_map: Mapping from node "nid" to list index. It is not needed to
            build the figure and is kept only for backward compatibility.
        text_dset: Indexable dataset used to look up example texts.
        text_field_name: Name of the text field in ``text_dset`` rows.
        max_examples: Maximum number of example texts shown per node label.
        max_chars: Maximum characters shown per example text.

    Returns:
        A plotly ``go.Figure`` with one line trace for the edges and one
        marker trace for the nodes. For an empty ``node_list`` both traces
        are empty.
    """
    for nid, node in enumerate(node_list):
        # Only build a label when one is not already present, to avoid
        # needless dataset lookups.
        if "label" not in node:
            node["label"] = _make_node_label(
                nid, node, text_dset, text_field_name, max_examples, max_chars
            )
    labels = [node["label"] for node in node_list]

    if node_list:
        _layout_tree(node_list[0])

    node_x, node_y = [], []
    edge_x, edge_y = [], []
    for node in node_list:
        node_x.append(node["X"])
        node_y.append(node["Y"])
        for child in node["children"]:
            # A None entry breaks the polyline so each edge is drawn separately.
            edge_x += [node["X"], child["X"], None]
            edge_y += [node["Y"], child["Y"], None]

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=edge_x,
            y=edge_y,
            mode="lines",
            line=dict(color="rgb(210,210,210)", width=1),
            hoverinfo="none",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=node_x,
            y=node_y,
            mode="markers",
            name="nodes",
            marker=dict(
                symbol="circle-dot",
                size=18,
                color="#6175c1",
                line=dict(color="rgb(50,50,50)", width=1),
            ),
            text=labels,
            hoverinfo="text",
            opacity=0.8,
        )
    )
    return fig