def make_bar_chart()

in sig-contributor-experience/surveys/k8s_survey_analysis/plot_utils.py [0:0]


def make_bar_chart(survey_data, topic, facet_by=None, proportional=False):
    """Make a bar chart of how many respondents selected each option of a topic.

    One bar is drawn per distinct ``level_0`` value in the long-format data
    returned by ``get_single_year_data_subset`` (presumably one per survey
    column starting with ``topic`` -- confirm against that helper). Only rows
    whose ``rating`` equals 1 are counted. If ``facet_by`` is not empty, the
    plot is split into subplots with ``plotnine.facet_grid``.

    Args:
        survey_data (pandas.DataFrame): Raw data read in from Kubernetes Survey
        topic (str): String that all questions of interest start with. It is
            also stripped from the axis and legend labels.
        facet_by (list, optional): Facet specification handed to
            ``plotnine.facet_grid``. Entries other than ``"."`` are also used
            as grouping columns; ``"."`` is treated purely as a facet_grid
            placeholder and never as a column name. The list is not modified,
            so the position of ``"."`` reaches ``facet_grid`` exactly as
            given. Defaults to no faceting.
        proportional (bool, optional): Defaults to False. If True and facet
            columns are given, each bar height is the count divided by the
            total number of counted responses in that facet. Has no effect
            without facet columns.

    Returns:
        (plotnine.ggplot): Plot object which can be displayed in a notebook or saved out to a file
    """
    # Work on a private copy: the caller's list (and the default) must never
    # be mutated, otherwise the position of "." -- and with it the facet
    # orientation -- would silently change.
    facet_spec = list(facet_by) if facet_by else []
    # "." only has meaning for facet_grid; it must not reach groupby/merge.
    group_cols = [col for col in facet_spec if col != "."]
    show_legend = bool(facet_spec)

    # Hand the helper its own copy so it cannot reorder facet_spec either.
    topic_data_long = get_single_year_data_subset(
        survey_data, topic, list(facet_spec)
    )

    # Rows that count as a "yes" for an option, with incomplete rows dropped.
    selected = topic_data_long[topic_data_long.rating == 1].dropna()

    # After count(), the "rating" column holds the number of selected rows
    # per option (and per facet combination, if any).
    aggregate_data = (
        selected.groupby(["level_0"] + group_cols).count().reset_index()
    )

    if proportional and group_cols:
        # Total selected rows per facet combination, used as the denominator.
        facet_sums = selected.groupby(group_cols).count().reset_index()

        # Both frames share every non-key column name, so merge suffixes
        # them with _x (per-option counts) and _y (per-facet totals).
        aggregate_data = aggregate_data.merge(facet_sums, on=group_cols).rename(
            columns={"level_0_x": "level_0"}
        )
        aggregate_data = aggregate_data.assign(
            rating=aggregate_data.rating_x / aggregate_data.rating_y
        )

    # Take the options from the unfiltered data so that options nobody
    # selected still get a slot on the axis and in the legend.
    options = topic_data_long["level_0"].unique().tolist()

    # Axis labels: drop the topic prefix, wrap at 35 chars, keep two lines.
    axis_labels = [
        "\n".join(
            textwrap.wrap(
                option.replace(topic, "").replace("_", " "), width=35
            )[0:2]
        )
        for option in options
    ]

    br = (
        p9.ggplot(aggregate_data, p9.aes(x="level_0", fill="level_0", y="rating"))
        + p9.geom_bar(show_legend=show_legend, stat="identity")
        + p9.theme(
            axis_text_x=p9.element_text(angle=45, ha="right"),
            strip_text_y=p9.element_text(angle=0, ha="left"),
        )
        + p9.scale_x_discrete(limits=options, labels=axis_labels)
    )

    if facet_spec:
        # Legend labels replace the (blanked) x-axis text in faceted plots.
        legend_labels = [
            "\n".join(
                textwrap.wrap(
                    option.replace(topic, "")
                    .replace("_", " ")
                    .replace("/", "/  ")
                    .strip(),
                    30,
                )
            )
            for option in options
        ]
        br = (
            br
            + p9.facet_grid(
                facet_spec,
                shrink=False,
                labeller=lambda label: "\n".join(textwrap.wrap(label, 15)),
            )
            + p9.theme(
                axis_text_x=p9.element_blank(),
                strip_text_x=p9.element_text(
                    wrap=True, va="bottom", margin={"b": -0.5}
                ),
            )
            + p9.scale_fill_discrete(limits=options, labels=legend_labels)
        )
    return br