in causalml/match.py [0:0]
def match(self, data, treatment_col, score_cols):
    """Find matches from the control group by matching on specified columns
    (propensity preferred).

    Two strategies are used depending on ``self.replace``:

    * ``replace=True``: the score columns are standardized, then each
      treated row is paired with its ``self.ratio`` nearest control rows.
      A pair is kept when its distance, divided by
      ``sqrt(len(score_cols))``, is below ``self.caliper``. A control row
      may be selected for more than one treated row.
    * ``replace=False``: greedy one-to-one matching on a single score
      column. Treated rows are visited in index order (or in a random
      order when ``self.shuffle`` is set) and each takes the closest
      still-unmatched control row, provided the absolute difference is
      within ``self.caliper`` times the standard deviation of the score.
      Matching stops as soon as no unmatched control rows remain.

    Args:
        data (pandas.DataFrame): total input data
        treatment_col (str): the column name for the treatment
        score_cols (list): list of column names for matching (propensity
            column should be included)

    Returns:
        (pandas.DataFrame): The subset of data consisting of matched
            treatment and control group data. Matched treatment rows come
            first, followed by matched control rows.

    Raises:
        AssertionError: if ``score_cols`` is not a list, or if more than
            one score column is given while ``self.replace`` is False.
    """
    assert isinstance(score_cols, list), "score_cols must be a list"

    treatment = data.loc[data[treatment_col] == 1, score_cols]
    control = data.loc[data[treatment_col] == 0, score_cols]

    # Caliper expressed in the units of the raw score(s); only used by the
    # without-replacement branch (the other branch rescales the scores).
    sdcal = self.caliper * np.std(data[score_cols].values)

    if self.replace:
        scaler = StandardScaler()
        scaler.fit(data[score_cols])
        treatment_scaled = pd.DataFrame(
            scaler.transform(treatment), index=treatment.index
        )
        control_scaled = pd.DataFrame(
            scaler.transform(control), index=control.index
        )

        # SD is the same as caliper because we use a StandardScaler above
        sdcal = self.caliper

        matching_model = NearestNeighbors(
            n_neighbors=self.ratio, n_jobs=self.n_jobs
        )
        matching_model.fit(control_scaled)
        distances, indices = matching_model.kneighbors(treatment_scaled)

        # distances and indices are (n_obs, self.ratio) matrices.
        # To index easily, reshape distances, indices and treatment into
        # the (n_obs * self.ratio, 1) matrices and data frame.
        distances = distances.T.flatten()
        indices = indices.T.flatten()
        treatment_scaled = pd.concat([treatment_scaled] * self.ratio, axis=0)

        # Normalize the distance by the number of matching dimensions
        # before comparing it against the caliper.
        cond = (distances / np.sqrt(len(score_cols))) < sdcal

        # Deduplicate the indices of the treatment group
        t_idx_matched = np.unique(treatment_scaled.loc[cond].index)
        # XXX: Should we deduplicate the indices of the control group too?
        c_idx_matched = np.array(control_scaled.iloc[indices[cond]].index)
    else:
        assert len(score_cols) == 1, (
            "Matching on multiple columns is only supported using the "
            "replacement method (if matching on multiple columns, set "
            "replace=True)."
        )
        # unpack score_cols for the single-variable matching case
        score_col = score_cols[0]

        if self.shuffle:
            t_indices = self.random_state.permutation(treatment.index)
        else:
            t_indices = treatment.index

        t_idx_matched = []
        c_idx_matched = []

        # Track which control rows are still available in a separate
        # boolean Series rather than adding a helper column to `control`,
        # which is a slice of `data` (writing to it triggers pandas'
        # SettingWithCopyWarning and mixes state into the score frame).
        unmatched = pd.Series(True, index=control.index)

        for t_idx in t_indices:
            # Once every control row has been used there is nothing left
            # to match against; idxmin() on an empty Series would raise.
            if not unmatched.any():
                break

            dist = np.abs(
                control.loc[unmatched, score_col]
                - treatment.loc[t_idx, score_col]
            )
            c_idx_min = dist.idxmin()
            if dist.loc[c_idx_min] <= sdcal:
                t_idx_matched.append(t_idx)
                c_idx_matched.append(c_idx_min)
                unmatched.loc[c_idx_min] = False

    return data.loc[
        np.concatenate([np.array(t_idx_matched), np.array(c_idx_matched)])
    ]