in causalml/metrics/sensitivity.py [0:0]
def plot(self, sens_df, partial_rsqs_df=None, type='raw', ci=False, partial_rsqs=False):
    """Plot the results of a sensitivity analysis against unmeasured confounding.

    Args:
        sens_df (pandas.DataFrame): a data frame output from causalsens; must contain
            the columns 'New ATE', 'New ATE LB' and 'New ATE UB', plus 'alpha'
            (always needed for ``type='raw'``; needed for ``type='r.squared'`` when
            ``partial_rsqs=True``) and 'rsqs' (for ``type='r.squared'``)
        partial_rsqs_df (pandas.DataFrame, optional): a data frame output from causalsens
            including partial R-squared values in a 'partial_rsqs' column; required
            when ``type='r.squared'`` and ``partial_rsqs=True``
        type (str, optional): the type of plot to draw, 'raw' or 'r.squared' are supported
        ci (bool, optional): whether to plot confidence intervals (shaded band between
            'New ATE LB' and 'New ATE UB')
        partial_rsqs (bool, optional): whether to plot partial R-squared results
            (only used when ``type='r.squared'``)

    Raises:
        ValueError: if ``type`` is not supported, if partial R-squared results are
            requested without ``partial_rsqs_df``, or if ``sens_df`` has no row with
            ``alpha == 0`` to anchor the partial R-squared markers
    """
    if type not in ('raw', 'r.squared'):
        raise ValueError("type must be 'raw' or 'r.squared', got {!r}".format(type))

    # Validate the partial R-squared inputs before creating a figure so a bad call
    # does not leave an empty figure behind.
    draw_partial = type == 'r.squared' and partial_rsqs
    if draw_partial:
        if partial_rsqs_df is None:
            raise ValueError('partial_rsqs_df is required when partial_rsqs=True')
        base_ate = sens_df.loc[sens_df['alpha'] == 0, 'New ATE']
        if base_ate.empty:
            raise ValueError('sens_df has no row with alpha == 0 to anchor the '
                             'partial R-squared markers')

    def _padded_limits(lo, hi):
        # Widen [lo, hi] by 10% of each endpoint's magnitude. Plain 0.9/1.1 scaling
        # shrinks the range when an endpoint is negative and clips the data; for
        # non-negative endpoints this gives the same result as that scaling.
        return round(lo - 0.1 * abs(lo), 4), round(hi + 0.1 * abs(hi), 4)

    x = sens_df['alpha'] if type == 'raw' else sens_df['rsqs']

    _, ax = plt.subplots()
    ax.set_ylim(*_padded_limits(sens_df['New ATE LB'].min(), sens_df['New ATE UB'].max()))
    if type == 'raw':
        # Only the alpha axis gets fixed limits; the R-squared x-axis autoscales.
        ax.set_xlim(*_padded_limits(x.min(), x.max()))

    if ci:
        ax.fill_between(x, sens_df['New ATE LB'], sens_df['New ATE UB'], color='gray', alpha=0.5)
    ax.plot(x, sens_df['New ATE'])

    if draw_partial:
        # One marker per row of partial_rsqs_df, drawn at the ATE for alpha == 0.
        ax.scatter(partial_rsqs_df['partial_rsqs'],
                   [base_ate.iloc[0]] * partial_rsqs_df.shape[0],
                   marker='x', color="red", linewidth=10)