def plot_tmlegain()

in causalml/metrics/visualize.py [0:0]


def plot_tmlegain(df, inference_col, learner=None,
                  outcome_col='y', treatment_col='w', p_col='tau', n_segment=5, cv=None,
                  calibrate_propensity=True, ci=False, figsize=(8, 8)):
    """Plot the lift chart based on TMLE estimation.

    The gain table is computed by ``get_tmlegain`` and drawn with matplotlib.
    The figure is displayed with ``plt.show()``; nothing is returned.

    Args:
        df (pandas.DataFrame): a data frame with model estimates and actual data as columns
        inference_col (list of str): a list of columns that used in learner for inference
        learner (optional): a model used by TMLE to estimate the outcome. When omitted (``None``),
            a fresh ``LGBMRegressor(num_leaves=64, learning_rate=.05, n_estimators=300)`` is
            created for this call.
        outcome_col (str, optional): the column name for the actual outcome
        treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
        p_col (str, optional): the column name for propensity score
        n_segment (int, optional): number of segment that TMLE will estimated for each
        cv (sklearn.model_selection._BaseKFold, optional): sklearn CV object
        calibrate_propensity (bool, optional): whether calibrate propensity score or not
        ci (bool, optional): whether return confidence intervals for ATE or not. When True, each
            model other than 'Random' is drawn with a shaded band between its " LB" and " UB"
            columns.
        figsize (tuple, optional): the size of the figure to plot
    """
    if learner is None:
        # Build the default estimator per call: a default created in the signature is
        # instantiated once at import time and then shared (and refit) by every caller.
        learner = LGBMRegressor(num_leaves=64, learning_rate=.05, n_estimators=300)

    plot_df = get_tmlegain(df, learner=learner, inference_col=inference_col, outcome_col=outcome_col,
                           treatment_col=treatment_col, p_col=p_col, n_segment=n_segment, cv=cv,
                           calibrate_propensity=calibrate_propensity, ci=ci)
    if ci:
        # Collapse "<model>", "<model> LB" and "<model> UB" into one entry per model.
        # Only a trailing suffix is stripped, and first-seen column order is kept so
        # that the model -> color assignment is the same on every run.
        model_names = []
        for column in plot_df.columns:
            name = column
            for suffix in (" LB", " UB"):
                if name.endswith(suffix):
                    name = name[:-len(suffix)]
                    break
            if name not in model_names:
                model_names.append(name)

        fig, ax = plt.subplots(figsize=figsize)
        cmap = plt.get_cmap("tab10")

        for cindex, col in enumerate(model_names):
            # Wrap around the palette so models beyond its size do not all get the last color.
            color = cmap(cindex % cmap.N)

            # An explicit label is required for ax.legend() below to have entries.
            ax.plot(plot_df.index, plot_df[col], color=color, label=col)

            # The 'Random' baseline is drawn without a confidence band.
            if col != 'Random':
                ax.fill_between(plot_df.index, plot_df[col + " LB"], plot_df[col + " UB"],
                                color=color, alpha=0.25)

        ax.legend()
    else:
        plot_df.plot(figsize=figsize)

    plt.xlabel('Population')
    plt.ylabel('Gain')
    plt.show()