def show_venn_diagram()

in identity-resolution/notebooks/identity-graph/nepytune/visualizations/venn_diagram.py [0:0]


def show_venn_diagram(intersections, labels):
    """
    Draw a three-circle Venn diagram and scatter the given items inside it.

    Every item of every intersection is drawn as a marker at a random
    position inside the region assigned to that intersection. A region is
    described by the list of triangles stored under the same key in the
    module-level ``TRIANGLES`` mapping. The figure is displayed with
    ``fig.show()``; nothing is returned.

    Args:
        intersections: Mapping from a region key (must also be a key of
            ``TRIANGLES``) to a non-empty sized collection of items. Each
            item is passed to ``make_label`` to build its hover text.
        labels: Sequence of at least three captions, used in order for the
            red, blue and green circles.

    Raises:
        ValueError: If any of the given intersections is empty.
        IndexError: If ``labels`` holds fewer than three entries.
    """
    def point_on_triangle(pt1, pt2, pt3):
        """
        Random point on the triangle with vertices pt1, pt2 and pt3.
        """
        # Two sorted uniform samples split [0, 1] into three segments whose
        # lengths are used as the weights of the three vertices.
        s, t = sorted([random.random(), random.random()])
        return (s * pt1[0] + (t - s) * pt2[0] + (1 - t) * pt3[0],
                s * pt1[1] + (t - s) * pt2[1] + (1 - t) * pt3[1])

    def area(tri):
        """
        Area of the triangle ``tri`` (three (x, y) vertices), shoelace formula.
        """
        x1, y1 = tri[0][0], tri[0][1]
        x2, y2 = tri[1][0], tri[1][1]
        x3, y3 = tri[2][0], tri[2][1]
        return abs(x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2)) / 2

    empty_sets = [k for k, v in intersections.items() if not len(v)]

    if empty_sets:
        raise ValueError(f"Given intersections \"{empty_sets}\" are empty, cannot continue")

    scatters = []

    for k, v in intersections.items():
        triangles = TRIANGLES[k]
        # Weighting triangles by their true area makes the points uniformly
        # distributed over the whole region rather than per triangle.
        weights = [area(triangle) for triangle in triangles]
        if not any(weights):
            # random.choices rejects an all-zero weight vector; fall back to
            # choosing uniformly among the (degenerate) triangles.
            weights = None

        chosen = random.choices(triangles, weights=weights, k=len(v))
        points_pairs = [point_on_triangle(*triangle) for triangle in chosen]
        x, y = zip(*points_pairs)
        scatter_labels = [make_label(n) for n in v]

        scatters.append(
            go.Scatter(
                x=x,
                y=y,
                mode='markers',
                showlegend=False,
                text=scatter_labels,
                marker=dict(
                    size=10,
                    line=dict(width=2,
                              color='DarkSlateGrey'),
                    opacity=1,
                ),
                hoverinfo="text",
            )
        )

    fig = go.Figure(
        data=scatters,
        layout=go.Layout(
            title_text="",
            autosize=False,
            # `title_font_size` is the non-deprecated spelling of the former
            # `titlefont_size` and matches the `title_text` form used above.
            title_font_size=16,
            showlegend=True,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            # Lock the y scale to the x scale so the circles stay round.
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, scaleanchor="x", scaleratio=1)
        ),
    )

    # Bounding boxes (x0, y0, x1, y1) and fill colour of the three circles.
    circles = [
        (0, -6, 12, 6, "Red"),
        (8, -6, 20, 6, "Blue"),
        (4, -12, 16, 0, "Green"),
    ]
    shapes = [
        go.layout.Shape(
            type="circle",
            x0=x0,
            y0=y0,
            x1=x1,
            y1=y1,
            fillcolor=colour,
            opacity=0.15,
            layer='below'
        )
        for x0, y0, x1, y1, colour in circles
    ]

    # Arrow tip (x, y), arrow tail offset (ax, ay) and caption of every
    # annotation: top of the red circle, top of the blue circle and bottom
    # of the green circle. Labels are indexed explicitly so that a too-short
    # `labels` sequence still fails loudly instead of dropping a caption.
    captions = [
        (6, 6, -50, -25, labels[0]),
        (14, 6, 50, -25, labels[1]),
        (10, -12, 50, 25, labels[2]),
    ]
    annotations = [
        dict(
            xref="x",
            yref="y",
            x=tip_x, y=tip_y,
            text=caption,
            font=dict(size=15),
            showarrow=True,
            arrowwidth=2,
            ax=tail_x,
            ay=tail_y,
            arrowhead=7,
        )
        for tip_x, tip_y, tail_x, tail_y, caption in captions
    ]

    fig.update_layout(shapes=shapes, annotations=annotations)

    fig.show()