def plot()

in causalml/metrics/sensitivity.py [0:0]


    def plot(sens_df, partial_rsqs_df=None, type="raw", ci=False, partial_rsqs=False):
        """Plot the results of a sensitivity analysis against unmeasured confounders.

        Draws the ``"New ATE"`` curve from ``sens_df`` on a new figure, optionally
        with a shaded band between ``"New ATE LB"`` and ``"New ATE UB"``.

        Args:
            sens_df (pandas.DataFrame): a data frame output from causalsens. It must
                contain the columns ``"New ATE"``, ``"New ATE LB"`` and
                ``"New ATE UB"``, plus ``alpha`` (used as the x-axis for
                ``type="raw"`` and to locate the unadjusted ATE) and/or ``rsqs``
                (used as the x-axis for ``type="r.squared"``).
            partial_rsqs_df (pandas.DataFrame, optional): a data frame output from
                causalsens including partial R-squared values in a ``partial_rsqs``
                column. Required when ``type="r.squared"`` and ``partial_rsqs=True``.
            type (str, optional): the type of plot to draw, 'raw' or 'r.squared'
                are supported
            ci (bool, optional): whether to plot the confidence interval band
            partial_rsqs (bool, optional): whether to mark the partial R-squared
                results; only used when ``type="r.squared"``

        Raises:
            ValueError: if ``type`` is not supported, or if partial R-squared
                markers are requested without ``partial_rsqs_df`` or without an
                ``alpha == 0`` row in ``sens_df``.
        """

        def _padded_limits(lower, upper):
            # Widen [lower, upper] by 10% of each bound's magnitude, rounded to 4
            # decimals. A fixed ``min * 0.9`` / ``max * 1.1`` only pads outward for
            # positive values; for negative values it pulls the bound inward and
            # clips the plotted data, so the factor is flipped by sign.
            lo = lower * 0.9 if lower >= 0 else lower * 1.1
            hi = upper * 1.1 if upper >= 0 else upper * 0.9
            return round(lo, 4), round(hi, 4)

        if type == "raw":
            x_col = "alpha"
        elif type == "r.squared":
            x_col = "rsqs"
        else:
            raise ValueError(
                f"Unsupported plot type {type!r}; expected 'raw' or 'r.squared'."
            )

        # Validate the partial R-squared inputs before creating a figure so a bad
        # call does not leave an empty figure behind.
        show_partial = bool(type == "r.squared" and partial_rsqs)
        if show_partial:
            if partial_rsqs_df is None:
                raise ValueError(
                    "partial_rsqs_df is required when partial_rsqs=True."
                )
            # The markers are drawn at the ATE of the alpha == 0 row.
            baseline_ate = sens_df.loc[sens_df.alpha == 0, "New ATE"]
            if baseline_ate.empty:
                raise ValueError(
                    "sens_df has no row with alpha == 0 to place the partial "
                    "R-squared markers at."
                )

        _, ax = plt.subplots()
        x = sens_df[x_col]

        # The y-range always spans the CI bounds, even when the band is not drawn,
        # so plots with and without ``ci`` share the same scale.
        ax.set_ylim(
            *_padded_limits(sens_df["New ATE LB"].min(), sens_df["New ATE UB"].max())
        )
        if type == "raw":
            # Only the raw plot pins the x-axis; the r.squared x-axis autoscales.
            ax.set_xlim(*_padded_limits(x.min(), x.max()))

        if ci:
            ax.fill_between(
                x,
                sens_df["New ATE LB"],
                sens_df["New ATE UB"],
                color="gray",
                alpha=0.5,
            )
        ax.plot(x, sens_df["New ATE"])

        if show_partial:
            n_points = partial_rsqs_df.shape[0]
            ax.scatter(
                partial_rsqs_df.partial_rsqs,
                [baseline_ate.iloc[0]] * n_points,
                marker="x",
                color="red",
                linewidth=10,
            )