in vihds/plotting.py [0:0]
def xval_individual_2treatments(res, device_id):
    """Multi-panel plot for each sample, with treatments separated into 2 groups.

    Samples belonging to ``device_id`` are split into two groups: those with a
    positive value in treatment column 0 and those with a positive value in
    treatment column 1. Each group is drawn as its own column of rows, one row
    per sample. Rows are sorted by increasing treatment value, with the lowest
    value in the top row. Each row has one panel per signal showing:

    - the observations,
    - the predicted mean,
    - a +/- 2 std band.

    All three are normalised by that signal's maximum over ``res.X_obs``.

    Args:
        res: Results object. This function reads ``X_obs``, ``devices``,
            ``treatments``, ``times``, ``iw_predict_mu``, ``iw_predict_std``,
            ``settings.conditions`` and ``settings.signals``. ``X_obs``,
            ``iw_predict_mu`` and ``iw_predict_std`` are indexed as
            ``[sample, signal, time]``; ``treatments`` as ``[sample, treatment]``.
        device_id: Value compared element-wise against ``res.devices`` to select
            the samples to plot.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If no sample of ``device_id`` has a positive value in either
            of the first two treatment columns (there would be nothing to plot).
    """
    nplots = res.X_obs.shape[1]
    colors = ["tab:gray", "r", "y", "c"]
    # Per-signal maximum over all samples and time points, so that every panel of a
    # given signal shares the same normalisation. A zero maximum would give 0/0
    # (NaN) curves, so such signals are divided by 1 instead.
    maxs = np.max(res.X_obs, axis=(0, 2))
    maxs = np.where(maxs == 0, 1.0, maxs)
    fs = 14

    # For each of the two treatment columns: indices of this device's samples where
    # that treatment is positive, ordered by increasing treatment value.
    both_locs = []
    for col in range(2):
        all_locs = np.where((res.devices == device_id) & (res.treatments[:, col] > 0.0))[0]
        order = np.argsort(res.treatments[all_locs, col])
        both_locs.append(all_locs[order])
    ntreatments = max(len(locs) for locs in both_locs)
    if ntreatments == 0:
        # Without this guard the layout below divides by zero.
        raise ValueError(
            f"xval_individual_2treatments: no samples for device {device_id!r} with a "
            "positive value in either of the first two treatment columns"
        )

    f = pp.figure(figsize=(12, 1.5 * ntreatments))
    # Layout in figure-fraction coordinates, shared by both groups (only `left` differs).
    bottom = 0.4 / ntreatments  # offset of the lowest row from the bottom edge
    width = 0.33 / nplots  # panel width
    dx = 0.38 / nplots  # horizontal step between neighbouring panels in a row
    dy = (1 - bottom) / ntreatments  # vertical step between rows
    height = 0.8 * dy  # panel height; the remaining 20% separates rows

    for col, locs in enumerate(both_locs):
        left = 0.1 + col * 0.5
        nrows = len(locs)
        for i, loc in enumerate(locs):
            # TODO(ndalchau): Incorporate units into conditions specification (here we assume nM)
            treatment_str = gen_treatment_str(res.settings.conditions, res.treatments[loc], unit="nM")
            # Row 0 (lowest treatment value) is drawn at the top of the figure.
            row_from_bottom = ntreatments - i - 1
            for idx, maxi in enumerate(maxs):
                ax = f.add_subplot(
                    ntreatments, 2 * nplots, col * nplots + row_from_bottom * 2 * nplots + idx + 1,
                )
                # The grid slot above only gives each axes a unique index; the real
                # placement is set explicitly here.
                ax.set_position([left + idx * dx, bottom + row_from_bottom * dy, width, height])
                color = colors[idx % len(colors)]  # cycle when there are more signals than colors
                mu = res.iw_predict_mu[loc, idx, :]
                std = res.iw_predict_std[loc, idx, :]
                ax.fill_between(
                    res.times, (mu - 2 * std) / maxi, (mu + 2 * std) / maxi, alpha=0.25, color=color,
                )
                ax.plot(res.times, res.X_obs[loc, idx, :] / maxi, "k.", markersize=2)
                ax.plot(res.times, mu / maxi, "-", lw=2, alpha=0.75, color=color)
                ax.set_xlim(0.0, 17)
                ax.set_xticks([0, 5, 10, 15])
                ax.set_ylim(-0.2, 1.2)
                ax.tick_params(axis="both", which="major", labelsize=fs)
                if i == 0:
                    ax.set_title(res.settings.signals[idx], fontsize=fs)
                # Keep x tick labels on the last row of *this* group, so a group with
                # fewer samples than the other still gets a labelled time axis.
                if i < nrows - 1:
                    ax.set_xticklabels([])
                if idx == 0:
                    ax.set_ylabel(treatment_str, labelpad=25, fontsize=fs - 2)
                else:
                    ax.set_yticklabels([])
        # Axis labels for this group, positioned relative to its column of panels.
        f.text(
            left - 0.35 * dx, 0.5, "Normalized output", ha="center", va="center", rotation=90, fontsize=fs,
        )
        # 0.5 * nplots * dx equals the original 2 * dx for 4 signals, and stays
        # centred under the row for other signal counts.
        f.text(left + 0.5 * nplots * dx, 0, "Time (h)", ha="center", va="bottom", fontsize=fs)
    sns.despine(fig=f)
    return f