def distr_plot_single_sim()

in causalml/dataset/synthetic.py [0:0]


def distr_plot_single_sim(synthetic_preds, kind='kde', drop_learners=None, bins=50, histtype='step', alpha=1,
                          linewidth=1, bw_method=1):
    """Plots the distribution of each learner's predictions (for a single simulation).

    Kernel Density Estimation (kde) and actual histogram plots are supported. The x-axis
    range (and the data shown) is clipped to the 1st-99th percentile of all predictions
    pooled together, including those of learners listed in ``drop_learners``.

    Args:
        synthetic_preds (dict): dictionary of predictions generated by get_synthetic_preds().
            The entry stored under KEY_GENERATED_DATA, if present, is never plotted.
        kind (str, optional): 'kde' or 'hist'
        drop_learners (list, optional): list of learners (str) to omit when plotting.
            ``None`` (the default) omits nothing.
        bins (int, optional): number of bins to plot if kind set to 'hist'
        histtype (str, optional): histogram type if kind set to 'hist'
        alpha (float, optional): alpha (transparency) for plotting; only used if kind set to 'hist'
        linewidth (int, optional): line width for plotting
        bw_method (float, optional): parameter for kde

    Raises:
        ValueError: if ``kind`` is neither 'kde' nor 'hist', or if there are no predictions
            left to plot once the generated data has been removed.
    """
    # Fail fast on an unsupported plot type instead of silently drawing an empty figure.
    if kind not in ('kde', 'hist'):
        raise ValueError("kind must be 'kde' or 'hist', got {!r}".format(kind))

    dropped = () if drop_learners is None else drop_learners

    # Work on a shallow copy so the caller's dictionary is left untouched.
    preds_for_plot = synthetic_preds.copy()

    # Remove the generated data; tolerate dictionaries that were already stripped of it.
    preds_for_plot.pop(KEY_GENERATED_DATA, None)
    if not preds_for_plot:
        raise ValueError('synthetic_preds contains no learner predictions to plot')

    # Pool every prediction into one flat array. A list is built explicitly because
    # NumPy's stacking helpers reject non-sequence iterables such as dict views.
    pooled = np.concatenate([np.ravel(v) for v in preds_for_plot.values()])
    global_lower, global_upper = np.percentile(pooled, [1, 99])

    # Plotting
    plt.figure(figsize=(12, 8))
    colors = ['black', 'red', 'blue', 'green', 'cyan', 'brown', 'grey', 'pink', 'orange', 'yellow']
    # `bins` bins require `bins + 1` edges.
    bin_edges = np.linspace(global_lower, global_upper, bins + 1)

    # The index runs over all learners (dropped ones included) so that a learner keeps
    # the same colour regardless of which other learners are omitted.
    for i, (learner, preds) in enumerate(preds_for_plot.items()):
        if learner in dropped:
            continue

        # Wrap around the palette rather than failing when there are more learners than colours.
        color = colors[i % len(colors)]
        values = np.ravel(preds)

        if kind == 'kde':
            series = pd.Series(values)
            series = series[series.between(global_lower, global_upper)]
            series.plot(kind='kde', bw_method=bw_method, label=learner, linewidth=linewidth, color=color)
        else:
            plt.hist(values, bins=bin_edges, label=learner, histtype=histtype,
                     alpha=alpha, linewidth=linewidth, color=color)

    plt.xlim(global_lower, global_upper)
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.title('Distribution from a Single Simulation')