def make_tree_plot()

in data_measurements/embeddings/embeddings.py [0:0]


def _format_node_label(nid, node, text_dset, text_field_name, num_examples, max_chars):
    """Build the HTML hover label for one node: header line plus example lines."""
    # Collect up to `num_examples` distinct texts, in the order they appear in
    # node["sorted_examples_centroid"]. Keying by text collapses duplicates
    # (a later duplicate overwrites the earlier score, as before).
    node_examples = {}
    for sid, score in node["sorted_examples_centroid"]:
        if len(node_examples) >= num_examples:
            break
        node_examples[text_dset[sid][text_field_name]] = score
    example_lines = [
        # Only add the ellipsis when characters were actually cut off.
        f" {score:.2f} > {txt[:max_chars]}" + ("..." if len(txt) > max_chars else "")
        for txt, score in node_examples.items()
    ]
    return f"{nid:2d} - {node['weight']:5d} items <br>" + "<br>".join(example_lines)


def _assign_tree_coordinates(root):
    """Write "X"/"Y" layout coordinates into every node reachable from `root`.

    Uses an explicit stack instead of recursion so deep trees cannot hit the
    interpreter's recursion limit. Each child is placed one level below its
    parent; siblings are spread along X by their weights.
    """
    root["X"] = 0
    root["Y"] = 0
    stack = [root]
    while stack:
        node = stack.pop()
        children = node["children"]
        if not children:
            continue
        # Examples owned by this node but by none of its children are split
        # evenly and added as extra horizontal spacing after each child.
        extra_per_child = (
            len(node["example_ids"]) - sum(child["weight"] for child in children)
        ) / len(children)
        offset = 0
        for child in children:
            child["X"] = node["X"] + offset
            child["Y"] = node["Y"] - 1
            offset += child["weight"] + extra_per_child
        stack.extend(children)


def make_tree_plot(
    node_list, nid_map, text_dset, text_field_name, *, num_examples=5, max_chars=64
):
    """
    Makes a graphical representation of the tree encoded in `node_list`.

    Every node is drawn as a marker connected to its children by grey lines.
    The hover label of each node shows its index, its weight, and up to
    `num_examples` example texts with their scores, taken from the front of
    node["sorted_examples_centroid"] (presumably ordered by closeness to the
    node centroid -- confirm against the producer of node_list).

    Args:
        node_list: list of node dicts; node_list[0] is the root. Every node
            needs "children" (list of node dicts) and must be reachable from
            the root. Nodes without a precomputed "label" also need "weight"
            (int) and "sorted_examples_centroid" (iterable of
            (example_id, score) pairs). Nodes with children need "example_ids"
            and children need "weight". This function writes "label", "X"
            and "Y" into the node dicts in place.
        nid_map: accepted for backward compatibility; no longer used (it only
            fed an edge list that was never consumed).
        text_dset: indexable dataset where
            text_dset[example_id][text_field_name] is the example's text.
        text_field_name: name of the text field in each dataset row.
        num_examples: maximum number of distinct example texts per label.
        max_chars: texts longer than this are truncated and suffixed "...".

    Returns:
        A plotly Figure with one line trace (edges) and one marker trace
        (nodes). An empty Figure is returned when node_list is empty.
    """
    if not node_list:
        return go.Figure()

    for nid, node in enumerate(node_list):
        # Keep a precomputed label; only build one (which reads the dataset)
        # when it is missing.
        if "label" not in node:
            node["label"] = _format_node_label(
                nid, node, text_dset, text_field_name, num_examples, max_chars
            )
    labels = [node["label"] for node in node_list]

    _assign_tree_coordinates(node_list[0])

    # Node positions, and edge segments; the None entries break the line
    # trace so each parent->child edge is drawn as its own segment.
    Xn = [node["X"] for node in node_list]
    Yn = [node["Y"] for node in node_list]
    Xe = []
    Ye = []
    for node in node_list:
        for child in node["children"]:
            Xe += [node["X"], child["X"], None]
            Ye += [node["Y"], child["Y"], None]

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=Xe,
            y=Ye,
            mode="lines",
            line=dict(color="rgb(210,210,210)", width=1),
            hoverinfo="none",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=Xn,
            y=Yn,
            mode="markers",
            name="nodes",
            marker=dict(
                symbol="circle-dot",
                size=18,
                color="#6175c1",
                line=dict(color="rgb(50,50,50)", width=1),
            ),
            text=labels,
            hoverinfo="text",
            opacity=0.8,
        )
    )
    return fig