def plot()

in causalml/metrics/sensitivity.py [0:0]


    def plot(self, sens_df, partial_rsqs_df=None, type='raw', ci=False, partial_rsqs=False):
        """Plot the results of a sensitivity analysis against unmeasured confounders.

        Draws the 'New ATE' column of ``sens_df`` on a new matplotlib figure,
        either against the ``alpha`` column (``type='raw'``) or against the
        ``rsqs`` column (``type='r.squared'``).

        Args:
            sens_df (pandas.DataFrame): a data frame output from causalsens. The
                columns read here are 'New ATE', 'New ATE LB', 'New ATE UB',
                'alpha' and, for ``type='r.squared'``, 'rsqs'.
            partial_rsqs_df (pandas.DataFrame, optional): a data frame output from
                causalsens including partial R-squared values (column
                'partial_rsqs'). Required when ``partial_rsqs`` is True and
                ``type='r.squared'``.
            type (str, optional): the type of plot to draw; 'raw' or 'r.squared'
                are supported.
            ci (bool, optional): whether to shade the band between 'New ATE LB'
                and 'New ATE UB'.
            partial_rsqs (bool, optional): whether to overlay the partial
                R-squared values as red markers (only used for
                ``type='r.squared'``).

        Raises:
            ValueError: if ``type`` is not supported, or if the partial R-squared
                overlay is requested without ``partial_rsqs_df``.
        """
        if type not in ('raw', 'r.squared'):
            raise ValueError("type must be 'raw' or 'r.squared', got {!r}".format(type))
        if type == 'r.squared' and partial_rsqs and partial_rsqs_df is None:
            raise ValueError('partial_rsqs_df is required when partial_rsqs=True')

        def _padded_limits(lower, upper, pad=0.1):
            # Widen [lower, upper] outward by `pad` times each bound's magnitude.
            # For positive bounds this equals the former (lower*0.9, upper*1.1),
            # but unlike plain scaling it never clips negative bounds.
            return (round(lower - pad * abs(lower), 4),
                    round(upper + pad * abs(upper), 4))

        x = sens_df.alpha if type == 'raw' else sens_df.rsqs

        _, ax = plt.subplots()
        # Limits are fixed before drawing, so later artists do not rescale them.
        plt.ylim(*_padded_limits(sens_df['New ATE LB'].min(), sens_df['New ATE UB'].max()))
        if type == 'raw':
            plt.xlim(*_padded_limits(sens_df.alpha.min(), sens_df.alpha.max()))

        if ci:
            ax.fill_between(x, sens_df['New ATE LB'], sens_df['New ATE UB'], color='gray', alpha=0.5)
        ax.plot(x, sens_df['New ATE'])

        if type == 'r.squared' and partial_rsqs:
            # Place each partial R-squared marker at the height of the ATE
            # estimated with no confounding (rows where alpha == 0).
            ate_without_confounding = list(sens_df[sens_df.alpha == 0]['New ATE'])
            ax.scatter(partial_rsqs_df.partial_rsqs,
                       ate_without_confounding * partial_rsqs_df.shape[0],
                       marker='x', color="red", linewidth=10)