def get_synthetic_summary_holdout()

in causalml/dataset/synthetic.py [0:0]
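
The body below relies on module-level names defined elsewhere in synthetic.py. A rough sketch of what this excerpt assumes (exact import paths and constant values may differ in the source):

import numpy as np
import pandas as pd
from scipy.stats import entropy                        # entropy(p, q) computes the KL divergence
from sklearn.metrics import mean_squared_error as mse

KEY_ACTUAL = 'Actuals'                  # key for the true treatment effects in the prediction dicts
KEY_GENERATED_DATA = 'generated_data'   # key for the raw simulated data, excluded from the summaries
# get_synthetic_preds_holdout() is defined earlier in the same module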


def get_synthetic_summary_holdout(synthetic_data_func, n=1000, valid_size=0.2, k=1):
    """Generate a summary of predictions on synthetic data for the train and holdout sets using the specified function

    Args:
        synthetic_data_func (function): synthetic data generation function
        n (int, optional): number of samples per simulation
        valid_size (float, optional): validation/holdout data size
        k (int, optional): number of simulations

    Returns:
        (tuple): summary evaluation metrics of predictions for train and validation:

          - summary_train (pandas.DataFrame): training data evaluation summary
          - summary_validation (pandas.DataFrame): validation data evaluation summary
    """

    summaries_train = []
    summaries_validation = []

    for _ in range(k):
        preds_dict_train, preds_dict_valid = get_synthetic_preds_holdout(synthetic_data_func, n=n,
                                                                         valid_size=valid_size)
        actuals_train = preds_dict_train[KEY_ACTUAL]
        actuals_validation = preds_dict_valid[KEY_ACTUAL]

        # per-learner summary: ATE estimate (mean predicted effect) and MSE against the actual
        # effects, skipping the raw generated-data entries
        synthetic_summary_train = pd.DataFrame({label: [preds.mean(), mse(preds, actuals_train)] for label, preds
                                                in preds_dict_train.items() if KEY_GENERATED_DATA not in label.lower()},
                                               index=['ATE', 'MSE']).T
        synthetic_summary_train['Abs % Error of ATE'] = np.abs(
            (synthetic_summary_train['ATE']/synthetic_summary_train.loc[KEY_ACTUAL, 'ATE']) - 1)

        synthetic_summary_validation = pd.DataFrame({label: [preds.mean(), mse(preds, actuals_validation)]
                                                     for label, preds in preds_dict_valid.items()
                                                     if KEY_GENERATED_DATA not in label.lower()},
                                                    index=['ATE', 'MSE']).T
        synthetic_summary_validation['Abs % Error of ATE'] = np.abs(
            (synthetic_summary_validation['ATE']/synthetic_summary_validation.loc[KEY_ACTUAL, 'ATE']) - 1)

        # calculate KL divergence for the training split
        for label in synthetic_summary_train.index:
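            # build a shared bin grid from the pooled predictions and actuals, trimmed at the
            # 0.1/99.9 percentiles so extreme outliers do not stretch the histogram range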
            stacked_values = np.hstack((preds_dict_train[label], actuals_train))
            stacked_low = np.percentile(stacked_values, 0.1)
            stacked_high = np.percentile(stacked_values, 99.9)
            bins = np.linspace(stacked_low, stacked_high, 100)

            # normalized histogram of the predictions, clipped away from zero bins so the
            # KL divergence stays finite
            distr = np.histogram(preds_dict_train[label], bins=bins)[0]
            distr = np.clip(distr/distr.sum(), 0.001, 0.999)
            # normalized histogram of the actual treatment effects, clipped the same way
            true_distr = np.histogram(actuals_train, bins=bins)[0]
            true_distr = np.clip(true_distr/true_distr.sum(), 0.001, 0.999)

            kl = entropy(distr, true_distr)
            synthetic_summary_train.loc[label, 'KL Divergence'] = kl

        # calculate KL divergence for the validation split (same procedure as above)
        for label in synthetic_summary_validation.index:
            stacked_values = np.hstack((preds_dict_valid[label], actuals_validation))
            stacked_low = np.percentile(stacked_values, 0.1)
            stacked_high = np.percentile(stacked_values, 99.9)
            bins = np.linspace(stacked_low, stacked_high, 100)

            distr = np.histogram(preds_dict_valid[label], bins=bins)[0]
            distr = np.clip(distr/distr.sum(), 0.001, 0.999)
            true_distr = np.histogram(actuals_validation, bins=bins)[0]
            true_distr = np.clip(true_distr/true_distr.sum(), 0.001, 0.999)

            kl = entropy(distr, true_distr)
            synthetic_summary_validation.loc[label, 'KL Divergence'] = kl

        summaries_train.append(synthetic_summary_train)
        summaries_validation.append(synthetic_summary_validation)

    # element-wise average of the per-simulation summaries over the k runs
    summary_train = sum(summaries_train) / k
    summary_validation = sum(summaries_validation) / k
    return (summary_train[['Abs % Error of ATE', 'MSE', 'KL Divergence']],
            summary_validation[['Abs % Error of ATE', 'MSE', 'KL Divergence']])
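
A minimal usage sketch (not from the source): it assumes that get_synthetic_summary_holdout and the simulate_nuisance_and_easy_treatment data generator are both importable from causalml.dataset, as in recent causalml releases.

from causalml.dataset import simulate_nuisance_and_easy_treatment, get_synthetic_summary_holdout

# average metrics over 5 simulations of 10,000 samples each, holding out 20% for validation
summary_train, summary_validation = get_synthetic_summary_holdout(
    simulate_nuisance_and_easy_treatment, n=10000, valid_size=0.2, k=5)

print(summary_train)       # Abs % Error of ATE, MSE, and KL Divergence per learner (training split)
print(summary_validation)  # the same metrics on the held-out split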