in causalml/metrics/sensitivity.py [0:0]
def plot(sens_df, partial_rsqs_df=None, type="raw", ci=False, partial_rsqs=False):
    """Plot the results of a sensitivity analysis against unmeasured confounding.

    Creates a new matplotlib figure and draws the "New ATE" column of
    ``sens_df`` against ``alpha`` (``type="raw"``) or against ``rsqs``
    (``type="r.squared"``). Nothing is returned; the figure is left as the
    current pyplot figure.

    Args:
        sens_df (pandas.DataFrame): a data frame output from causalsens. The
            columns "New ATE", "New ATE LB", "New ATE UB" and ``alpha`` are
            read, plus ``rsqs`` when type is 'r.squared'
        partial_rsqs_df (pandas.DataFrame, optional): a data frame output from
            causalsens including partial r-square results (column
            ``partial_rsqs``); required when ``partial_rsqs`` is True and type
            is 'r.squared'
        type (str, optional): the type of plot to draw, 'raw' or 'r.squared'
            are supported; any other value draws nothing
        ci (bool, optional): whether to shade the band between "New ATE LB"
            and "New ATE UB"
        partial_rsqs (bool, optional): whether to mark the partial r-square
            results; only used when type is 'r.squared'

    Raises:
        ValueError: if ``partial_rsqs`` is True for an 'r.squared' plot but
            ``partial_rsqs_df`` was not provided.
    """
    if type not in ("raw", "r.squared"):
        # Unsupported plot types have always been a silent no-op; keep that
        # contract so existing callers are unaffected.
        return

    with_partials = type == "r.squared" and partial_rsqs
    if with_partials and partial_rsqs_df is None:
        # Fail before a figure is created instead of leaving an empty figure
        # behind and raising an opaque AttributeError on None.
        raise ValueError(
            "partial_rsqs_df must be provided when partial_rsqs is True"
        )

    def _padded_limits(lowest, highest):
        """Return (lower, upper) axis limits padded 10% away from the data."""
        # Scaling a negative minimum by 0.9 (or a negative maximum by 1.1)
        # would move the limit *into* the data and clip the plot, so the
        # factor is chosen from the sign of each bound. For non-negative
        # bounds this is the usual min * 0.9 / max * 1.1 padding.
        lower = lowest * 0.9 if lowest >= 0 else lowest * 1.1
        upper = highest * 1.1 if highest >= 0 else highest * 0.9
        return round(lower, 4), round(upper, 4)

    x_values = sens_df.alpha if type == "raw" else sens_df.rsqs
    ate_lower = sens_df["New ATE LB"]
    ate_upper = sens_df["New ATE UB"]

    _, ax = plt.subplots()

    # The y-range always spans the confidence bounds, even when the band
    # itself is not drawn.
    ax.set_ylim(*_padded_limits(ate_lower.min(), ate_upper.max()))
    if type == "raw":
        # Only the raw plot pins the x-range; the r.squared plot autoscales.
        ax.set_xlim(*_padded_limits(sens_df.alpha.min(), sens_df.alpha.max()))

    if ci:
        ax.fill_between(x_values, ate_lower, ate_upper, color="gray", alpha=0.5)
    ax.plot(x_values, sens_df["New ATE"])

    if with_partials:
        # Place one marker per partial r-square at the height of the
        # "New ATE" value found where alpha == 0 (repeated to match the
        # number of rows in partial_rsqs_df).
        ate_at_zero_alpha = list(sens_df.loc[sens_df.alpha == 0, "New ATE"])
        ax.scatter(
            partial_rsqs_df.partial_rsqs,
            ate_at_zero_alpha * partial_rsqs_df.shape[0],
            marker="x",
            color="red",
            linewidth=10,
        )