def __init__()

in dowhy/causal_estimators/distance_matching_estimator.py [0:0]


    def __init__(self, *args, **kwargs):
        """Validate the inputs and configure the distance matching estimator.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor. After that call this method reads
        ``self._treatment_name``, ``self._data``, ``self._target_estimand``
        and ``self.logger``, so the parent is expected to have set them
        (the parent class is not visible from this method -- confirm there).

        Optional settings are read from attributes on ``self``; a missing
        attribute or a value of ``None`` falls back to the default:

        - ``num_matches_per_unit``: defaults to ``1``.
        - ``distance_metric``: defaults to ``'minkowski'``.
        - ``exact_match_cols``: defaults to ``None``. When set, those names
          are removed from the back-door variables kept in
          ``self._observed_common_causes_names``.
        - every name listed in ``self.Valid_Dist_Metric_Params``: values that
          are not ``None`` are collected into ``self.distance_metric_params``.

        :raises Exception: if more than one treatment variable is given, if
            the treatment column does not have a boolean dtype, or if the
            target estimand reports no back-door variables.
        """
        super().__init__(*args, **kwargs)

        # Check if the treatment is one-dimensional
        if len(self._treatment_name) > 1:
            # Leading space keeps the class repr and the message apart.
            error_msg = str(self.__class__) + " cannot handle more than one treatment variable"
            self.logger.error(error_msg)
            raise Exception(error_msg)

        # Checking if the treatment is binary
        if not pd.api.types.is_bool_dtype(self._data[self._treatment_name[0]]):
            error_msg = "Distance Matching method is applicable only for binary treatments"
            self.logger.error(error_msg)
            raise Exception(error_msg)

        # Setting the number of matches per data point
        if getattr(self, 'num_matches_per_unit', None) is None:
            self.num_matches_per_unit = 1
        # Default distance metric if not provided by the user
        if getattr(self, 'distance_metric', None) is None:
            self.distance_metric = 'minkowski'  # corresponds to euclidean metric with p=2
        # Make sure the attribute exists even when the user did not supply it,
        # so that later code can read it without an AttributeError.
        if getattr(self, 'exact_match_cols', None) is None:
            self.exact_match_cols = None

        # Fetch the back-door variables once and reuse them below.
        backdoor_variables = self._target_estimand.get_backdoor_variables()
        # ``or []`` keeps the join from raising a TypeError on None, so the
        # descriptive error below is the one the user actually sees.
        self.logger.debug("Back-door variables used:" +
                          ",".join(backdoor_variables or []))

        self._observed_common_causes_names = backdoor_variables
        if not self._observed_common_causes_names:
            self._observed_common_causes = None
            error_msg = "No common causes/confounders present. Distance matching methods are not applicable"
            self.logger.error(error_msg)
            raise Exception(error_msg)

        if self.exact_match_cols is not None:
            # Columns matched exactly are excluded from the distance features.
            self._observed_common_causes_names = [
                v for v in self._observed_common_causes_names
                if v not in self.exact_match_cols
            ]
        self._observed_common_causes = self._data[self._observed_common_causes_names]
        # Convert the categorical variables into dummy/indicator variables
        # Basically, this gives a one hot encoding for each category
        # The first category is taken to be the base line.
        self._observed_common_causes = pd.get_dummies(self._observed_common_causes, drop_first=True)

        # Dictionary of any user-provided params for the distance metric
        # that will be passed to sklearn nearestneighbors
        self.distance_metric_params = {}
        for param_name in self.Valid_Dist_Metric_Params:
            param_val = getattr(self, param_name, None)
            if param_val is not None:
                self.distance_metric_params[param_name] = param_val

        self.logger.info("INFO: Using Distance Matching Estimator")
        self.symbolic_estimator = self.construct_symbolic_estimator(self._target_estimand)
        self.logger.info(self.symbolic_estimator)
        # Match results start out unset; nothing in this method fills them.
        self.matched_indices_att = None
        self.matched_indices_atc = None