def get_synthetic_summary_holdout()

in causalml/dataset/synthetic.py


def get_synthetic_summary_holdout(synthetic_data_func, n=1000, valid_size=0.2, k=1):
    """Generate a summary for predictions on synthetic data for train and holdout using specified function

    Args:
        synthetic_data_func (function): synthetic data generation function
        n (int, optional): number of samples per simulation
        valid_size (float, optional): validation/holdout data size
        k (int, optional): number of simulations


    Returns:
        (tuple): summary evaluation metrics of predictions for train and validation:

          - summary_train (pandas.DataFrame): training data evaluation summary
          - summary_validation (pandas.DataFrame): validation data evaluation summary
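
    See the usage sketch after the function body for an example call.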
    """

    summaries_train = []
    summaries_validation = []

    # run k independent simulations and collect a per-run evaluation summary
    for i in range(k):
        preds_dict_train, preds_dict_valid = get_synthetic_preds_holdout(
            synthetic_data_func, n=n, valid_size=valid_size
        )
        actuals_train = preds_dict_train[KEY_ACTUAL]
        actuals_validation = preds_dict_valid[KEY_ACTUAL]

        # summarize each prediction series (excluding raw generated data) by
        # ATE (mean prediction) and MSE against the training actuals
        synthetic_summary_train = pd.DataFrame(
            {
                label: [preds.mean(), mse(preds, actuals_train)]
                for label, preds in preds_dict_train.items()
                if KEY_GENERATED_DATA not in label.lower()
            },
            index=["ATE", "MSE"],
        ).T
        synthetic_summary_train["Abs % Error of ATE"] = np.abs(
            (
                synthetic_summary_train["ATE"]
                / synthetic_summary_train.loc[KEY_ACTUAL, "ATE"]
            )
            - 1
        )

        # repeat the same summary for the validation (holdout) predictions
        synthetic_summary_validation = pd.DataFrame(
            {
                label: [preds.mean(), mse(preds, actuals_validation)]
                for label, preds in preds_dict_valid.items()
                if KEY_GENERATED_DATA not in label.lower()
            },
            index=["ATE", "MSE"],
        ).T
        synthetic_summary_validation["Abs % Error of ATE"] = np.abs(
            (
                synthetic_summary_validation["ATE"]
                / synthetic_summary_validation.loc[KEY_ACTUAL, "ATE"]
            )
            - 1
        )

        # calculate kl divergence for training
        for label in synthetic_summary_train.index:
            stacked_values = np.hstack((preds_dict_train[label], actuals_train))
            stacked_low = np.percentile(stacked_values, 0.1)
            stacked_high = np.percentile(stacked_values, 99.9)
            bins = np.linspace(stacked_low, stacked_high, 100)

            distr = np.histogram(preds_dict_train[label], bins=bins)[0]
            distr = np.clip(distr / distr.sum(), 0.001, 0.999)
            true_distr = np.histogram(actuals_train, bins=bins)[0]
            true_distr = np.clip(true_distr / true_distr.sum(), 0.001, 0.999)

            kl = entropy(distr, true_distr)
            synthetic_summary_train.loc[label, "KL Divergence"] = kl

        # calculate kl divergence for validation
        for label in synthetic_summary_validation.index:
            stacked_values = np.hstack((preds_dict_valid[label], actuals_validation))
            stacked_low = np.percentile(stacked_values, 0.1)
            stacked_high = np.percentile(stacked_values, 99.9)
            bins = np.linspace(stacked_low, stacked_high, 100)

            distr = np.histogram(preds_dict_valid[label], bins=bins)[0]
            distr = np.clip(distr / distr.sum(), 0.001, 0.999)
            true_distr = np.histogram(actuals_validation, bins=bins)[0]
            true_distr = np.clip(true_distr / true_distr.sum(), 0.001, 0.999)

            kl = entropy(distr, true_distr)
            synthetic_summary_validation.loc[label, "KL Divergence"] = kl

        summaries_train.append(synthetic_summary_train)
        summaries_validation.append(synthetic_summary_validation)

    # average the per-simulation summaries over the k runs
    summary_train = sum(summaries_train) / k
    summary_validation = sum(summaries_validation) / k
    return (
        summary_train[["Abs % Error of ATE", "MSE", "KL Divergence"]],
        summary_validation[["Abs % Error of ATE", "MSE", "KL Divergence"]],
    )
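

A minimal usage sketch, assuming causalml is installed and that simulate_nuisance_and_easy_treatment is available in causalml.dataset as one of the synthetic data generation functions (the parameter values below are illustrative):

# Minimal usage sketch; `simulate_nuisance_and_easy_treatment` and
# `get_synthetic_summary_holdout` are assumed to be importable from
# causalml.dataset.
from causalml.dataset import (
    get_synthetic_summary_holdout,
    simulate_nuisance_and_easy_treatment,
)

summary_train, summary_validation = get_synthetic_summary_holdout(
    simulate_nuisance_and_easy_treatment,
    n=10000,         # samples per simulation
    valid_size=0.2,  # fraction of each simulation held out for validation
    k=5,             # number of simulations averaged in the summary
)

# Each returned DataFrame is indexed by prediction label (learner or Actuals)
# with the averaged "Abs % Error of ATE", "MSE", and "KL Divergence" columns.
print(summary_train)
print(summary_validation)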