def get_synthetic_preds_holdout()

in causalml/dataset/synthetic.py [0:0]


def get_synthetic_preds_holdout(synthetic_data_func, n=1000, valid_size=0.2,
                                estimators=None):
    """Generate predictions for synthetic data using specified function (single simulation) for train and holdout

    The generated sample is split into a training part and a validation (holdout) part. Every model,
    including the propensity model, is fit on the training part only and then used to predict on
    both parts.

    Args:
        synthetic_data_func (function): synthetic data generation function; it is called as
            ``synthetic_data_func(n=n)`` and its result is unpacked as ``(y, X, w, tau, b, e)``
        n (int, optional): number of samples
        valid_size (float, optional): validation/hold out data size (passed to ``train_test_split``
            as ``test_size``)
        estimators (dict of object, optional): dict of names and objects of treatment effect estimators.
            Currently not used by this function: the S/T/X/R learners with the LR and XGB base
            models are always evaluated, whatever is passed here.

    Returns:
        (tuple): synthetic training and validation data dictionaries:

          - preds_dict_train (dict): synthetic training data dictionary
          - preds_dict_valid (dict): synthetic validation data dictionary

        Each dictionary stores the actual treatment effect under ``KEY_ACTUAL``, the split of the
        generated data under ``'generated_data'`` and one flattened prediction per learner under
        ``'<S|T|X|R> Learner (<LR|XGB>)'``.
    """
    y, X, w, tau, b, e = synthetic_data_func(n=n)

    # train_test_split returns a (train, validation) pair for each array, in the order given.
    X_train, X_val, y_train, y_val, w_train, w_val, tau_train, tau_val, b_train, b_val, e_train, e_val = \
        train_test_split(X, y, w, tau, b, e, test_size=valid_size, random_state=RANDOM_SEED, shuffle=True)

    preds_dict_train = {
        KEY_ACTUAL: tau_train,
        'generated_data': {
            'y': y_train,
            'X': X_train,
            'w': w_train,
            'tau': tau_train,
            'b': b_train,
            'e': e_train},
    }
    preds_dict_valid = {
        KEY_ACTUAL: tau_val,
        'generated_data': {
            'y': y_val,
            'X': X_val,
            'w': w_val,
            'tau': tau_val,
            'b': b_val,
            'e': e_val},
    }

    # Predict p_hat because e would not be directly observed in real-life.
    # The propensity model is fit on the training split only and merely applied to the holdout
    # split, so that the holdout treatment assignment is never used for fitting.
    p_model = ElasticNetPropensityModel()
    p_hat_train = p_model.fit_predict(X_train, w_train)
    p_hat_val = p_model.predict(X_val)

    meta_learners = (('S', BaseSRegressor), ('T', BaseTRegressor), ('X', BaseXRegressor), ('R', BaseRRegressor))
    base_models = (('LR', LinearRegression), ('XGB', XGBRegressor))

    for label_l, base_learner in meta_learners:
        for label_m, model in base_models:
            key = '{} Learner ({})'.format(label_l, label_m)
            learner = base_learner(model())

            # RLearner will need to fit on the p_hat
            if label_l == 'R':
                learner.fit(X=X_train, p=p_hat_train, treatment=w_train, y=y_train)
                preds_dict_train[key] = learner.predict(X=X_train).flatten()
                preds_dict_valid[key] = learner.predict(X=X_val).flatten()
                continue

            # fit the model on training data only
            learner.fit(X=X_train, treatment=w_train, y=y_train)
            try:
                preds_dict_train[key] = learner.predict(X=X_train, p=p_hat_train).flatten()
                preds_dict_valid[key] = learner.predict(X=X_val, p=p_hat_val).flatten()
            except TypeError:
                # Presumably raised by learners whose predict() does not take `p`;
                # retry with the (X, treatment, y) call form instead.
                preds_dict_train[key] = learner.predict(X=X_train, treatment=w_train, y=y_train).flatten()
                preds_dict_valid[key] = learner.predict(X=X_val, treatment=w_val, y=y_val).flatten()

    return preds_dict_train, preds_dict_valid