in vihds/plotting.py [0:0]
def xval_fit_summary(res, device_id, separatedInputs=False):
    """Summary plot of model-data fit for cross-validation results.

    Draws one column of panels per signal in ``res.settings.signals``. Each panel
    shows, for one representative sample per distinct treatment of ``device_id``,
    the observed data (dots), the predicted mean (line) and a band of
    +/- 2 predicted standard deviations.

    Args:
        res: Cross-validation results object. Attributes read here:
            ``settings.signals``, ``settings.conditions``, ``devices``,
            ``treatments`` (indexed ``[sample, condition]``), ``times``, and
            ``X_obs`` / ``iw_predict_mu`` / ``iw_predict_std`` (indexed
            ``[sample, signal, time]``, with the last axis plotted against
            ``times``).
        device_id: Value compared element-wise with ``res.devices`` to select
            the samples to plot.
        separatedInputs: If truthy, draw one row per condition in
            ``res.settings.conditions``, each using only samples whose treatment
            for that condition is positive. Otherwise draw a single row over all
            samples of the device.

    Returns:
        The matplotlib figure.
    """
    signals = res.settings.signals
    nplots = len(signals)
    fs = 14
    all_locs = []
    if separatedInputs:
        nrows = len(res.settings.conditions)
        for i in range(nrows):
            # Samples of this device that received a positive amount of input i.
            dev_locs = np.where((res.devices == device_id) & (res.treatments[:, i] > 0.0))[0]
            # Keep the first sample of each distinct input-i level. np.unique sorts
            # its values, so samples end up ordered by increasing treatment level.
            _, indices = np.unique(res.treatments[dev_locs, i], return_index=True)
            all_locs.append(dev_locs[indices])
        f, axs = pp.subplots(
            nrows,
            nplots,
            sharex=True,
            sharey=True,
            figsize=(2.2 * nplots, 1.6 * nrows + 1.2),
            squeeze=False,
        )
    else:
        nrows = 1
        dev_locs = np.where(res.devices == device_id)[0]
        # One sample per distinct treatment vector (unique rows of the treatment matrix).
        _, indices = np.unique(res.treatments[dev_locs, :], return_index=True, axis=0)
        all_locs.append(dev_locs[indices])
        f, axs = pp.subplots(1, nplots, sharey=True, figsize=(2.2 * nplots, 2.8), squeeze=False)
    # squeeze=False keeps axs 2-D (rows x signals) even when nrows or nplots is 1.
    # With the default squeezing, a single signal yields a bare Axes or a 1-D array,
    # and the [i, idx] indexing below would fail.
    for i, locs in enumerate(all_locs):
        colors = [cm.rainbow(x) for x in np.linspace(0, 1, np.shape(locs)[0])]  # pylint: disable=no-member
        for idx in range(nplots):
            ax = axs[i, idx]
            # Nothing to draw when no sample matches (e.g. the device never received
            # this input). Skip rather than install an empty color cycle.
            if len(locs) > 0:
                w_mu = res.iw_predict_mu[locs, idx, :]
                w_std = res.iw_predict_std[locs, idx, :]
                # The cycle holds exactly one color per sample. Each group of
                # len(locs) artists below (bands, observed dots, mean curves) restarts
                # the same color sequence, so each sample's three artists match.
                ax.set_prop_cycle("color", colors)
                for mu, std in zip(w_mu, w_std):
                    ax.fill_between(res.times, mu - 2 * std, mu + 2 * std, alpha=0.1)
                ax.plot(res.times, res.X_obs[locs, idx, :].T, ".", alpha=1, markersize=2)
                ax.plot(res.times, w_mu.T, "-", lw=2, alpha=0.75)
            ax.set_xlim(0.0, 17)
            ax.set_xticks([0, 5, 10, 15])
            ax.set_ylim(-0.2, 1.2)
            if idx == 0 and nrows > 1:
                ax.set_ylabel(
                    res.settings.conditions[i] + " dilution", labelpad=25, fontweight="bold", fontsize=fs,
                )
            if i == 0:
                ax.set_title(signals[idx], fontsize=fs)
    # Global axis labels: add one big frameless axis spanning the figure, hide its
    # ticks, and label it. Operate on this figure explicitly rather than through
    # pyplot's "current figure/axes" state.
    big_ax = f.add_subplot(111, frameon=False)
    big_ax.tick_params(labelcolor="none", top=False, bottom=False, left=False, right=False)
    big_ax.set_xlabel("Time (h)", fontsize=fs, labelpad=7)
    big_ax.set_ylabel("Normalized output", fontsize=fs, labelpad=7)
    f.tight_layout()
    sns.despine(fig=f)
    return f