def xy_dataset()

in dowhy/datasets.py [0:0]


def xy_dataset(num_samples, effect=True,
               num_common_causes=1,
               is_linear=True,
               sd_error=1,
               random_state=None):
    """Generate a synthetic dataset whose confounder is a function of time.

    A time variable ``s`` is drawn uniformly from [0, 10) and mapped through a
    piecewise quadratic to give the confounder ``w0``, which drives both the
    treatment and the outcome. When ``num_common_causes > 1``, the extra common
    causes ``w1, w2, ...`` are drawn from a unit-variance Gaussian with random
    means in [-1, 1). They enter the treatment and the outcome through random
    weights drawn from [0, 1).

    :param num_samples: number of rows to generate.
    :param effect: if True, the treatment value is added to the outcome (a
        unit effect of treatment on outcome). If False, a term depending only
        on ``w0`` is added instead, so treatment has no effect on outcome.
    :param num_common_causes: total number of common-cause columns, including
        ``w0``. Must be at least 1.
    :param is_linear: if True, ``w0`` enters treatment and outcome linearly;
        otherwise it enters as ``w0 * w0``.
    :param sd_error: standard deviation of the Gaussian noise added to the
        treatment and to the outcome.
    :param random_state: optional int seed or ``np.random.RandomState``.
        When None (default) the global ``np.random`` state is used, exactly
        as before this parameter existed.
    :returns: dict with the DataFrame under ``"df"`` plus the treatment,
        outcome and common-cause column names. The time column's name is
        stored under the key ``"time_val"``. The instrument, graph and ATE
        entries are None.
    :raises ValueError: if ``num_common_causes < 1``.
    """
    if num_common_causes < 1:
        # w0 is always generated and stored under common_causes[0].
        raise ValueError(
            "num_common_causes must be at least 1, got {}".format(num_common_causes))

    # The np.random module functions draw from the global RandomState, so
    # using the module itself keeps the default behavior unchanged.
    if random_state is None:
        rng = np.random
    elif isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    treatment = 'Treatment'
    outcome = 'Outcome'
    common_causes = ['w' + str(i) for i in range(num_common_causes)]
    time_var = 's'

    # Error terms. The draw order (E1, E2, S, then the extra common causes)
    # is kept fixed so a given seed always produces the same dataset.
    E1 = rng.normal(loc=0, scale=sd_error, size=num_samples)
    E2 = rng.normal(loc=0, scale=sd_error, size=num_samples)

    # Confounder w0 as a function of S:
    #   4 - (S-3)^2 for S < 5, (S-7)^2 - 4 for S > 5, and 0 at S == 5.
    S = rng.uniform(0, 10, num_samples)
    T1 = 4 - (S - 3) * (S - 3)
    T1[S >= 5] = 0
    T2 = (S - 7) * (S - 7) - 4
    T2[S <= 5] = 0
    W0 = T1 + T2

    # Contributions of the additional common causes (scalar 0 if there are none).
    tterm, yterm = 0, 0
    otherW = None
    if num_common_causes > 1:
        n_other = num_common_causes - 1
        means = rng.uniform(-1, 1, n_other)
        cov_mat = np.diag(np.ones(n_other))
        otherW = rng.multivariate_normal(means, cov_mat, num_samples)
        c1 = rng.uniform(0, 1, (otherW.shape[1], 1))
        c2 = rng.uniform(0, 1, (otherW.shape[1], 1))
        tterm = (otherW @ c1)[:, 0]
        yterm = (otherW @ c2)[:, 0]

    # How w0 enters both treatment and outcome.
    w0_term = W0 if is_linear else W0 * W0
    V = 6 + w0_term + tterm + E1
    Y = 6 + w0_term + yterm + E2
    if effect:
        Y += V
    else:
        # NOTE(review): in the nonlinear case this adds the *linear* 6 + W0
        # rather than 6 + W0 * W0. That is preserved from the original;
        # confirm whether it is intended.
        Y += (6 + W0)

    dat = {
        treatment: V,
        outcome: Y,
        common_causes[0]: W0,
        time_var: S,
    }
    if otherW is not None:
        for i in range(otherW.shape[1]):
            dat[common_causes[i + 1]] = otherW[:, i]
    data = pd.DataFrame(data=dat)
    ret_dict = {
        "df": data,
        "treatment_name": treatment,
        "outcome_name": outcome,
        "common_causes_names": common_causes,
        "time_val": time_var,
        "instrument_names": None,
        "dot_graph": None,
        "gml_graph": None,
        "ate": None,
    }
    return ret_dict