def draw_referenced_subgraph()

in identity-resolution/notebooks/identity-graph/nepytune/usecase/similar_audience.py [0:0]


def draw_referenced_subgraph(g, website_url, categories_limit=3, search_time_limit_in_seconds=15,
                             similar_audience_limit=15):
    """Draw the average profile of a website's audience together with similar users.

    The most popular categories across the audience of ``website_url`` are merged
    into an "average profile". Users whose activities are queried against that
    profile are then fetched. The resulting subgraph is laid out in shells and drawn.

    :param g: traversal source passed to the query helpers (presumably a Gremlin
        ``GraphTraversalSource`` -- the helpers' results are consumed via ``toList()``)
    :param website_url: website whose audience defines the average profile
    :param categories_limit: number of popular categories requested for the profile
    :param search_time_limit_in_seconds: time budget forwarded to the similar-audience query
    :param similar_audience_limit: maximum number of similar-audience results fetched
        and drawn (defaults to 15, the previously hard-coded value)
    """
    average_profile = _get_categories_popular_across_audience_of_website(
        g, website_url, categories_limit=categories_limit
    ).toList()
    # dict() needs one (key, value) pair per element, so each result is expected to be a
    # single-entry mapping; flatten them all into one {category: value} dict.
    average_profile = dict(
        chain(*category.items()) for category in average_profile
    )
    similar_audience = _query_users_activities_stats(
        g, website_url, average_profile, search_time_limit_in_seconds=search_time_limit_in_seconds
    ).limit(similar_audience_limit).toList()

    graph = _build_graph(average_profile, similar_audience)

    def _nodes_labelled(label):
        # Node names carrying the given "label" attribute, in graph iteration order.
        return [n for n, params in graph.nodes(data=True) if params["label"] == label]

    iabs = _nodes_labelled("IAB")
    # Graph.nodes[n] is the networkx 2.x accessor; Graph.node was removed in networkx 2.4.
    avg_iabs = [n for n in iabs if graph.nodes[n]["category"] in average_profile]
    avg_iab_set = set(avg_iabs)
    # Keep graph order instead of taking a set difference, whose iteration order depends on
    # string hash randomization and would make the layout differ between runs.
    other_iabs = [n for n in iabs if n not in avg_iab_set]

    graph_with_pos_computed = drawing.layout(
        graph,
        nx.shell_layout,
        nlist=[
            ["averageBuyer"],
            avg_iabs,
            other_iabs,
            _nodes_labelled("persistentId"),
            _nodes_labelled("transientId"),
        ]
    )

    def _shift_y(names, dy):
        # Move each named node vertically by dy, keeping "pos" as an [x, y] list.
        for name in names:
            attrs = graph_with_pos_computed.nodes[name]
            attrs["pos"] = [attrs["pos"][0], attrs["pos"][1] + dy]

    # Push categories outside the average profile down and lift the averageBuyer node with
    # its profile categories up, separating the two groups of IAB nodes vertically.
    _shift_y(other_iabs, -1.75)
    _shift_y(["averageBuyer"] + avg_iabs, 1.75)
    # Lift the averageBuyer node a bit further above its profile categories.
    _shift_y(["averageBuyer"], 1)

    drawing.draw(
        # NOTE(review): this title describes ecommerce conversions rather than a similar
        # audience; confirm it is the intended caption for this plot.
        title="User devices that visited ecommerce websites and optionally converted",
        scatters=list(
            drawing.edge_scatters_by_label(
                graph_with_pos_computed,
                dashes={
                    "interestedInButNotSufficient": "dash",
                    "interestedIn": "solid"
                }
            )) + list(
            drawing.scatters_by_label(
                graph_with_pos_computed, attrs_to_skip=["pos", "opacity"],
                sizes={
                    "averageBuyer": 30,
                    "IAB": 10,
                    "persistentId": 20
                }
            )
        )
    )