def _estimate_effect()

in dowhy/causal_estimators/distance_matching_estimator.py [0:0]


    def _estimate_effect(self):
        """Estimate the treatment effect by nearest-neighbour distance matching.

        Each unit of the target population is matched to its
        ``self.num_matches_per_unit`` nearest units of the opposite treatment
        arm (distances on the observed common causes, using
        ``self.distance_metric`` / ``self.distance_metric_params``).

        * ``"att"``: treated units are matched to control units.
        * ``"atc"``: control units are matched to treated units.
        * ``"ate"``: both; the result averages over every matched unit, i.e.
          ATT and ATC weighted by the number of matched treated / control units.

        The per-unit effect is always treated outcome minus (mean) control
        outcome. When ``self.exact_match_cols`` is set, both ATT and ATC
        matching are restricted to units sharing the same values in those
        columns.

        Side effects: sets ``self.matched_indices_att`` (when fitting ATT) and
        ``self.matched_indices_atc`` (when fitting ATC) to dicts mapping an
        index label of the input data to the list of index labels it was
        matched to.

        :returns: a ``CausalEstimate`` holding the point estimate.
        :raises ValueError: if ``self._target_units`` is not ``"att"``,
            ``"atc"`` or ``"ate"``; if there are no units to match; or if
            units cannot be matched because the opposite arm (within their
            exact-match stratum) is empty.
        """
        # this assumes a binary treatment regime (treated == 1, control == 0)
        updated_df = pd.concat([self._observed_common_causes,
            self._data[[self._outcome_name, self._treatment_name[0]]]], axis=1)
        if self.exact_match_cols is not None:
            updated_df = pd.concat([updated_df, self._data[self.exact_match_cols]], axis=1)

        if self._target_units == "att":
            fit_att, fit_atc = True, False
        elif self._target_units == "atc":
            fit_att, fit_atc = False, True
        elif self._target_units == "ate":
            fit_att, fit_atc = True, True
        else:
            raise ValueError("Target units string value not supported")

        # TODO remove neighbors that are more than a given radius apart
        strata = self._matching_strata(updated_df)
        effect_diffs = []
        if fit_att:
            # Treated outcome minus mean outcome of its matched controls.
            att_diffs, self.matched_indices_att = self._match_across_strata(
                strata, query_is_treated=True)
            effect_diffs.append(att_diffs)
        if fit_atc:
            # The helper returns control outcome minus matched treated mean;
            # negate so the sign is treated minus control, as for ATT.
            control_diffs, self.matched_indices_atc = self._match_across_strata(
                strata, query_is_treated=False)
            effect_diffs.append(-control_diffs)

        # Mean over all matched units. For "ate" this equals
        # (ATT * n_treated + ATC * n_control) / (n_treated + n_control).
        est = np.concatenate(effect_diffs).mean()

        estimate = CausalEstimate(estimate=est,
                                  control_value=self._control_value,
                                  treatment_value=self._treatment_value,
                                  target_estimand=self._target_estimand,
                                  realized_estimand_expr=self.symbolic_estimator)
        return estimate

    def _matching_strata(self, updated_df):
        """Split ``updated_df`` into the (treated, control) frame pairs to match within.

        Without ``self.exact_match_cols`` there is a single stratum holding
        every unit. Otherwise there is one stratum per group produced by
        ``updated_df.groupby(self.exact_match_cols)``. Pandas ``groupby``
        excludes rows with a NaN key by default, so such rows are not matched.

        :param updated_df: frame holding the common causes, outcome, treatment
            and (optionally) exact-match columns, indexed like ``self._data``.
        :returns: list of ``(treated, control)`` DataFrame pairs.
        """
        treatment_col = self._treatment_name[0]
        if self.exact_match_cols is None:
            groups = [updated_df]
        else:
            groups = [group for _, group in updated_df.groupby(self.exact_match_cols)]
        return [(group.loc[group[treatment_col] == 1],
                 group.loc[group[treatment_col] == 0])
                for group in groups]

    def _match_across_strata(self, strata, query_is_treated):
        """Match the units of one treatment arm to the other arm, stratum by stratum.

        :param strata: list of ``(treated, control)`` DataFrame pairs as
            returned by ``_matching_strata``.
        :param query_is_treated: if True, treated units are matched to control
            units (ATT); otherwise control units are matched to treated units
            (ATC).
        :returns: tuple ``(differences, matches)``. ``differences`` is a float
            array with, for every matched query unit, its own outcome minus the
            mean outcome of its matched units. ``matches`` maps each query
            unit's index label to the list of matched index labels.
        :raises ValueError: if there are no query units at all, or a stratum
            has query units but no units of the opposite arm.
        """
        query_arm, pool_arm = ("treated", "control") if query_is_treated else ("control", "treated")
        differences, matches = [], {}
        for treated, control in strata:
            query, pool = (treated, control) if query_is_treated else (control, treated)
            if query.shape[0] == 0:
                continue  # nothing to match in this stratum
            if pool.shape[0] == 0:
                raise ValueError(
                    "Cannot match {} {} unit(s): no {} units available "
                    "(within their exact-match stratum, if any)".format(
                        query.shape[0], query_arm, pool_arm))
            stratum_diffs, stratum_matches = self._match_nearest(query, pool)
            differences.append(stratum_diffs)
            matches.update(stratum_matches)
        if not differences:
            raise ValueError("No {} units available to match".format(query_arm))
        return np.concatenate(differences), matches

    def _match_nearest(self, query, pool):
        """Match every ``query`` row to its nearest ``pool`` rows.

        Distances are computed on the observed common-cause columns with a
        ball-tree ``NearestNeighbors`` model using ``self.distance_metric`` and
        ``self.distance_metric_params``.

        :param query: non-empty frame of units to be matched.
        :param pool: non-empty frame of candidate matches; it should hold at
            least ``self.num_matches_per_unit`` rows.
        :returns: tuple ``(differences, matches)``. ``differences[i]`` is the
            outcome of query row ``i`` minus the mean outcome of its
            ``self.num_matches_per_unit`` nearest pool rows. ``matches`` maps
            the query row's index label to the matched pool index labels.
        """
        covariates = self._observed_common_causes.columns
        neighbors = (
            NearestNeighbors(n_neighbors=self.num_matches_per_unit,
                metric=self.distance_metric,
                algorithm='ball_tree',
                **self.distance_metric_params)
            .fit(pool[covariates].values)
        )
        distances, indices = neighbors.kneighbors(query[covariates].values)
        self.logger.debug("distances:")
        self.logger.debug(distances)

        # indices has shape (len(query), num_matches_per_unit), so fancy
        # indexing + row-wise mean gives each query unit's matched mean outcome.
        query_outcomes = np.asarray(query[self._outcome_name].values, dtype=float)
        pool_outcomes = np.asarray(pool[self._outcome_name].values, dtype=float)
        differences = query_outcomes - pool_outcomes[indices].mean(axis=1)

        # Translate positional neighbour indices back to labels of the input data.
        pool_labels = pool.index
        matches = {label: pool_labels[row].tolist()
                   for label, row in zip(query.index, indices)}
        return differences, matches