def make_formula()

in src/mozanalysis/frequentist_stats/linear_models/functions.py [0:0]


def make_formula(target: str, ref_branch: str, covariate: str | None = None) -> str:
    """Makes a formula which defines the model to build. Includes terms for
    treatment branch and, optionally, a main effect for a (generally pre-experiment)
    covariate.

    Ex: given `target` of 'Y' and `ref_branch` of control, the formula will be:
    `Y ~ C(branch, Treatment(reference='control'))`. That is, we want to predict `Y`
    with treatment branch (`branch`) treated as a categorical predictor with
    reference level of `'control'`.

    This builds the following linear model (assuming 2 other treatment branches, t1
    and t2): `Y_i = \beta_0 + \beta_1*I(branch == 't1') + \beta_2*I(branch == 't2')`
    where `I` is the indicator function.

    Inferences on `\beta_1` are inferences of the average treatment effect (ATE) of
    branch t1. That is, the confidence interval for `\beta_1` is the confidence
    interval for the ATE of branch t1. Similarly for t2 and `\beta_2`.

    We can incorporate a single (generally pre-experiment) covariate. Adding a covariate
    of `Ypre` to the above, we'll build the following formula:
    `Y ~ C(branch, Treatment(reference='control')) + Ypre`
    which will fit the following linear model:

    `Y_i = \beta_0 + \beta_1*I(branch == 't1') + \beta_2*I(branch == 't2')
    + \beta_3* Ypre_i`

    For now, we elect to not include branch by covariate interaction terms and instead
    to perform inferences only on the population-level ATE.

    Parameters:
    - target (str): the variable of interest.
    - ref_branch (str): the name of the reference branch
    - covariate (Optional[str]): the name of a covariate to include in the model.

    Returns:
    - formula (str): the R-style formula, to be passed to statsmodels's formula API.

    """
    pattern = re.compile(r"(\(|\)|\~|\')")
    if pattern.findall(target):
        raise ValueError(f"Target variable {target} contains invalid character")
    if pattern.findall(ref_branch):
        raise ValueError(f"Reference branch {ref_branch} contains invalid character")
    if covariate is not None and pattern.findall(covariate):
        raise ValueError(f"Covariate {covariate} contains invalid character")

    formula = f"{target} ~ C(branch, Treatment(reference='{ref_branch}'))"
    if covariate is not None:
        formula += f" + {covariate}"

    return formula