in vihds/plotting.py [0:0]
def xval_individual(res, device_id):
    """Plot predicted vs. observed trajectories for every treatment of one device.

    The figure is laid out as two columns of treatments. Each treatment occupies
    one row containing one panel per output signal. Each panel shows:

    - the observed data as black dots;
    - the predicted mean as a solid line;
    - a shaded band of mean +/- 2 predicted standard deviations.

    Every curve is divided by the per-signal maximum of ``res.X_obs``. That
    maximum is taken over all samples and time points, not only this device's.

    Treatments are ordered by ``res.ids``. They fill the left column top to
    bottom, then the right column. With an odd number of treatments the last
    right-column slot is left empty.

    Args:
        res: Results object. Reads ``X_obs``, ``iw_predict_mu`` and
            ``iw_predict_std``, each indexed here as ``[sample, signal, time]``
            (presumably identical shapes -- TODO confirm). Also reads ``times``
            (x values, one per time index), ``devices``, ``ids``,
            ``treatments``, ``settings.conditions`` and ``settings.signals``.
        device_id: Value compared element-wise against ``res.devices`` to
            select the samples to plot.

    Returns:
        The newly created ``matplotlib`` figure.

    Raises:
        ValueError: If no entry of ``res.devices`` equals ``device_id``.
    """
    nplots = res.X_obs.shape[1]
    colors = ["tab:gray", "r", "y", "c"]
    # Per-signal normalisation constants: max over samples (axis 0) and time (axis 2).
    maxs = np.max(res.X_obs, axis=(0, 2))
    fs = 14

    # Select this device's samples, ordered by their id.
    locs = np.where(res.devices == device_id)[0]
    if len(locs) == 0:
        # Without this guard the layout maths below divides by zero rows.
        raise ValueError(f"No samples found for device {device_id!r}")
    locs = locs[np.argsort(res.ids[locs])]
    ntreatments = len(locs)
    nrows = int(np.ceil(ntreatments / 2.0))

    f = pp.figure(figsize=(12, 1.2 * nrows))
    # Layout constants in figure-fraction coordinates; identical for both columns.
    bottom = 0.4 / nrows
    width = 0.33 / nplots
    dx = 0.38 / nplots
    dy = (1 - bottom) / nrows
    height = 0.8 * dy
    for col in range(2):
        if col * nrows >= ntreatments:
            # Only one treatment: the right column is empty, so draw nothing there.
            break
        left = 0.1 + col * 0.5
        for i in range(nrows):
            k = i + col * nrows
            if k >= ntreatments:
                # Odd treatment count: the final right-column slot has no sample.
                continue
            loc = locs[k]
            treatment_str = gen_treatment_str(res.settings.conditions, res.treatments[loc])
            # Row 0 is drawn at the top of the figure.
            row_bottom = bottom + (nrows - i - 1) * dy
            for idx, maxi in enumerate(maxs):
                # Cycle colours so more than len(colors) signals do not raise IndexError.
                color = colors[idx % len(colors)]
                ax = f.add_subplot(nrows, 2 * nplots, col * nplots + (nrows - i - 1) * 2 * nplots + idx + 1)
                # Replace the grid-assigned position with the hand-computed layout.
                ax.set_position([left + idx * dx, row_bottom, width, height])
                mu = res.iw_predict_mu[loc, idx, :]
                std = res.iw_predict_std[loc, idx, :]
                ax.fill_between(
                    res.times, (mu - 2 * std) / maxi, (mu + 2 * std) / maxi, alpha=0.25, color=color,
                )
                ax.plot(res.times, res.X_obs[loc, idx, :] / maxi, "k.", markersize=2)
                ax.plot(res.times, mu / maxi, "-", lw=2, alpha=0.75, color=color)
                ax.set_xlim(0.0, 17)
                ax.set_xticks([0, 5, 10, 15])
                ax.set_ylim(-0.2, 1.2)
                ax.tick_params(axis="both", which="major", labelsize=fs)
                if i == 0:
                    ax.set_title(res.settings.signals[idx], fontsize=fs)
                # x tick labels are suppressed on every panel; the shared
                # "Time (h)" label below describes the x axis instead.
                ax.set_xticklabels([])
                if idx == 0:
                    ax.set_ylabel(treatment_str, labelpad=25, fontsize=fs - 2)
                else:
                    ax.set_yticklabels([])
        # Column-level axis labels.
        f.text(
            left - 0.35 * dx, 0.5, "Normalized output", ha="center", va="center", rotation=90, fontsize=fs,
        )
        # Centre the time label under this column's row of panels.
        span_centre = left + 0.5 * ((nplots - 1) * dx + width)
        f.text(span_centre, 0, "Time (h)", ha="center", va="bottom", fontsize=fs)
    sns.despine()
    return f