in dowhy/causal_estimators/distance_matching_estimator.py [0:0]
def __init__(self, *args, **kwargs):
    """Validate inputs and prepare the covariates used for distance matching.

    All positional and keyword arguments are forwarded to the parent
    estimator's ``__init__``. NOTE(review): the parent is assumed to set
    ``_treatment_name``, ``_data``, ``_target_estimand`` and ``logger``,
    since they are read below; the parent class is not visible here.

    After validation this sets:

    * ``num_matches_per_unit`` -- 1 if unset or None.
    * ``distance_metric`` -- ``'minkowski'`` if unset or None.
    * ``exact_match_cols`` -- None if unset.
    * ``_observed_common_causes_names`` -- the estimand's back-door
      variables, minus any listed in ``exact_match_cols``.
    * ``_observed_common_causes`` -- those columns of ``_data``, one-hot
      encoded via ``pd.get_dummies(..., drop_first=True)``.
    * ``distance_metric_params`` -- every non-None attribute of ``self``
      whose name appears in ``Valid_Dist_Metric_Params``.
    * ``symbolic_estimator`` -- built from the target estimand and logged.
    * ``matched_indices_att`` / ``matched_indices_atc`` -- reset to None.

    Raises:
        Exception: if more than one treatment variable is given, if the
            treatment column is not of boolean dtype, or if the estimand
            has no back-door variables.
    """
    super().__init__(*args, **kwargs)

    # Only a single treatment variable is supported.
    if len(self._treatment_name) > 1:
        error_msg = str(self.__class__) + " cannot handle more than one treatment variable"
        self.logger.error(error_msg)
        raise Exception(error_msg)

    # Only binary treatments are supported, checked via a boolean dtype.
    if not pd.api.types.is_bool_dtype(self._data[self._treatment_name[0]]):
        error_msg = "Distance Matching method is applicable only for binary treatments"
        self.logger.error(error_msg)
        raise Exception(error_msg)

    # Number of matches per data point; defaults to a single match.
    if getattr(self, 'num_matches_per_unit', None) is None:
        self.num_matches_per_unit = 1
    # Default distance metric if not provided by the user.
    # 'minkowski' with p=2 is the Euclidean metric.
    if getattr(self, 'distance_metric', None) is None:
        self.distance_metric = 'minkowski'
    # Make sure the attribute exists even if the caller never set it.
    self.exact_match_cols = getattr(self, 'exact_match_cols', None)

    # Query the estimand once. Guard the join against a None result so that
    # case reaches the explicit "no common causes" error below instead of
    # failing with a TypeError.
    backdoor_variables = self._target_estimand.get_backdoor_variables()
    self.logger.debug("Back-door variables used:" +
                      ",".join(backdoor_variables or []))
    self._observed_common_causes_names = backdoor_variables

    if not self._observed_common_causes_names:
        self._observed_common_causes = None
        error_msg = "No common causes/confounders present. Distance matching methods are not applicable"
        self.logger.error(error_msg)
        raise Exception(error_msg)

    if self.exact_match_cols is not None:
        # Columns requested for exact matching are dropped from the
        # covariates encoded below.
        # NOTE(review): if every back-door variable is in exact_match_cols,
        # this leaves an empty covariate frame; confirm downstream code
        # handles that.
        self._observed_common_causes_names = [
            v for v in self._observed_common_causes_names
            if v not in self.exact_match_cols
        ]
    self._observed_common_causes = self._data[self._observed_common_causes_names]
    # One-hot encode categorical columns. The first category of each is
    # dropped and acts as the baseline.
    self._observed_common_causes = pd.get_dummies(self._observed_common_causes, drop_first=True)

    # Any user-provided distance-metric parameters, to be passed on to
    # sklearn's NearestNeighbors.
    self.distance_metric_params = {}
    for param_name in self.Valid_Dist_Metric_Params:
        param_val = getattr(self, param_name, None)
        if param_val is not None:
            self.distance_metric_params[param_name] = param_val

    self.logger.info("INFO: Using Distance Matching Estimator")
    self.symbolic_estimator = self.construct_symbolic_estimator(self._target_estimand)
    self.logger.info(self.symbolic_estimator)
    # Populated later, outside this constructor.
    self.matched_indices_att = None
    self.matched_indices_atc = None