def get_synthetic_preds()

in causalml/dataset/synthetic.py [0:0]


def _fit_predict_with_optional_p(learner, X, w, y, p_hat):
    """Fit ``learner`` and return its flattened treatment-effect predictions.

    ``p_hat`` is first passed as the ``p`` keyword. If that call raises
    ``TypeError`` (e.g. the learner's ``fit_predict`` has no ``p`` parameter),
    the call is retried without ``p``.

    Args:
        learner (object): estimator exposing ``fit_predict(X=..., treatment=..., y=...)``
        X: features produced by the synthetic data function
        w: treatment assignments produced by the synthetic data function
        y: outcomes produced by the synthetic data function
        p_hat: estimated propensity scores

    Returns:
        The learner's predictions, flattened with ``.flatten()``.
    """
    try:
        preds = learner.fit_predict(X=X, treatment=w, y=y, p=p_hat)
    except TypeError:
        # NOTE(review): any TypeError raised inside fit_predict (not only an
        # unexpected ``p`` keyword) triggers this retry without ``p``.
        preds = learner.fit_predict(X=X, treatment=w, y=y)
    # Flatten outside the try so a failure here is not mistaken for a
    # signature mismatch and does not cause a second fit.
    return preds.flatten()


def get_synthetic_preds(synthetic_data_func, n=1000, estimators=None):
    """Generate predictions for synthetic data using specified function (single simulation)

    Args:
        synthetic_data_func (function): synthetic data generation function; called as
            ``synthetic_data_func(n=n)`` and expected to return ``(y, X, w, tau, b, e)``
        n (int, optional): number of samples
        estimators (dict of object, optional): dict of names and objects of treatment
            effect estimators. When ``None`` or empty, a default set of S/T/X/R learners
            (each with LinearRegression and XGBRegressor) plus a causal tree is used.

    Returns:
        (dict): dict of the actual and estimates of treatment effects, containing
            - ``KEY_ACTUAL``: the true treatment effects ``tau``
            - ``KEY_GENERATED_DATA``: dict of the generated ``y``, ``X``, ``w``,
              ``tau``, ``b`` and ``e``
            - one entry per estimator name with its flattened predictions
    """
    y, X, w, tau, b, e = synthetic_data_func(n=n)

    preds_dict = {
        KEY_ACTUAL: tau,
        KEY_GENERATED_DATA: {
            "y": y,
            "X": X,
            "w": w,
            "tau": tau,
            "b": b,
            "e": e,
        },
    }

    # Predict p_hat because e would not be directly observed in real-life
    p_model = ElasticNetPropensityModel()
    p_hat = p_model.fit_predict(X, w)

    if estimators:
        for name, learner in estimators.items():
            preds_dict[name] = _fit_predict_with_optional_p(learner, X, w, y, p_hat)
        return preds_dict

    # Default benchmark: every meta-learner paired with every base model.
    for base_learner, label_l in zip(
        [BaseSRegressor, BaseTRegressor, BaseXRegressor, BaseRRegressor],
        ["S", "T", "X", "R"],
    ):
        for model, label_m in zip([LinearRegression, XGBRegressor], ["LR", "XGB"]):
            model_name = "{} Learner ({})".format(label_l, label_m)
            preds_dict[model_name] = _fit_predict_with_optional_p(
                base_learner(model()), X, w, y, p_hat
            )

    # The causal tree is fitted without propensity scores.
    learner = CausalTreeRegressor(random_state=RANDOM_SEED)
    preds_dict["Causal Tree"] = learner.fit_predict(X=X, treatment=w, y=y).flatten()

    return preds_dict