in identity-resolution/notebooks/identity-graph/nepytune/visualizations/venn_diagram.py [0:0]
def _point_on_triangle(pt1, pt2, pt3):
    """Return a point drawn uniformly at random from the triangle pt1, pt2, pt3.

    Two sorted uniform variates ``s <= t`` cut [0, 1] into the pieces ``s``,
    ``t - s`` and ``1 - t``; used as barycentric weights of the three vertices
    they yield a point uniformly distributed over the triangle.
    """
    s, t = sorted([random.random(), random.random()])
    return (s * pt1[0] + (t - s) * pt2[0] + (1 - t) * pt3[0],
            s * pt1[1] + (t - s) * pt2[1] + (1 - t) * pt3[1])


def _triangle_area(tri):
    """Return the area of ``tri``, a sequence of three ``(x, y)`` vertices.

    Uses the shoelace formula. Half of the bounding-box area is NOT a valid
    substitute: it overestimates the area of most triangles, which would bias
    the area-proportional sampling in ``show_venn_diagram``.
    """
    (x1, y1), (x2, y2), (x3, y3) = ((p[0], p[1]) for p in tri)
    return abs(x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2)) / 2


def show_venn_diagram(intersections, labels):
    """Render a three-circle Venn diagram with one marker per member.

    Args:
        intersections: Mapping from a region key to the members belonging to
            that region. Every key must be present in the module-level
            ``TRIANGLES`` table, whose entry is a sequence of triangles
            (presumably covering that region of the diagram -- confirm against
            ``TRIANGLES``). Members are scattered over those triangles with
            probability proportional to triangle area, and each member is
            passed to ``make_label`` to build its hover text.
        labels: Sequence of at least three captions, attached in order to the
            red (left), blue (right) and green (bottom) circles. Extra entries
            are ignored.

    Raises:
        ValueError: If any region in ``intersections`` has no members, or if
            fewer than three labels are supplied.

    The figure is displayed via ``fig.show()``; nothing is returned.
    """
    empty_sets = [key for key, members in intersections.items() if len(members) == 0]
    if empty_sets:
        raise ValueError(f"Given intersections \"{empty_sets}\" are empty, cannot continue")
    # Validate before doing any work: previously a short `labels` only failed
    # with an IndexError after the whole figure had been assembled.
    if len(labels) < 3:
        raise ValueError(f"Expected at least 3 labels (one per circle), got {len(labels)}")

    scatters = []
    for key, members in intersections.items():
        # Materialise once so coordinates and hover texts come from one pass.
        region_members = list(members)
        triangles = TRIANGLES[key]
        weights = [_triangle_area(triangle) for triangle in triangles]
        # Area-weighted triangle choice + uniform point inside the triangle
        # gives a uniform distribution over the union of the triangles.
        chosen = random.choices(triangles, weights=weights, k=len(region_members))
        x, y = zip(*(_point_on_triangle(*triangle) for triangle in chosen))
        scatters.append(
            go.Scatter(
                x=x,
                y=y,
                mode='markers',
                showlegend=False,
                text=[make_label(member) for member in region_members],
                marker=dict(
                    size=10,
                    line=dict(width=2, color='DarkSlateGrey'),
                    opacity=1,
                ),
                hoverinfo="text",
            )
        )

    # Bounding boxes (x0, y0, x1, y1) and fill colour of the three radius-6
    # circles, centred at (6, 0), (14, 0) and (10, -6).
    circles = [
        (0, -6, 12, 6, "Red"),
        (8, -6, 20, 6, "Blue"),
        (4, -12, 16, 0, "Green"),
    ]
    shapes = [
        go.layout.Shape(
            type="circle",
            x0=x0, y0=y0, x1=x1, y1=y1,
            fillcolor=color,
            opacity=0.15,
            layer='below',
        )
        for x0, y0, x1, y1, color in circles
    ]

    # Arrow tip (x, y) on each circle's rim and the text offset (ax, ay) in
    # pixels, in the same order as `circles`.
    label_anchors = [(6, 6, -50, -25), (14, 6, 50, -25), (10, -12, 50, 25)]
    annotations = [
        dict(
            xref="x",
            yref="y",
            x=x, y=y,
            text=text,
            font=dict(size=15),
            showarrow=True,
            arrowwidth=2,
            ax=ax,
            ay=ay,
            arrowhead=7,
        )
        for text, (x, y, ax, ay) in zip(labels, label_anchors)
    ]

    fig = go.Figure(
        data=scatters,
        layout=go.Layout(
            title_text="",
            # `title_font` replaces the deprecated `titlefont` attribute.
            title_font_size=16,
            autosize=False,
            showlegend=True,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            # 1:1 axis scaling so the circles are not rendered as ellipses.
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False,
                       scaleanchor="x", scaleratio=1),
            shapes=shapes,
            annotations=annotations,
        ),
    )
    fig.show()