def xval_fit_summary()

in vihds/plotting.py [0:0]


def xval_fit_summary(res, device_id, separatedInputs=False):
    """Summary plot of model-data fit for cross-validation results.

    Draws one column of subplots per output signal. For every selected sample the
    observed data (``res.X_obs``) are drawn as dots, the predicted mean
    (``res.iw_predict_mu``) as a line, and a shaded band spanning
    mean +/- 2 * ``res.iw_predict_std``.

    Args:
        res: cross-validation results object. Must provide ``settings.signals``,
            ``settings.conditions``, ``devices``, ``treatments`` (indexed here as
            ``[sample, condition]``), ``times``, and ``X_obs`` / ``iw_predict_mu`` /
            ``iw_predict_std`` (indexed here as ``[sample, signal, time]``, with the
            last axis plotted against ``times``).
        device_id: value compared element-wise with ``res.devices`` to choose the
            samples to plot.
        separatedInputs: if truthy, draw one row per condition, each showing only the
            samples whose treatment for that condition is positive (one sample per
            distinct treatment level). Otherwise draw a single row with one sample per
            distinct treatment combination.

    Returns:
        The created matplotlib Figure.
    """
    nplots = len(res.settings.signals)
    fs = 14

    # One array of sample indices per figure row. np.unique(..., return_index=True)
    # returns the first occurrence of each distinct treatment, so every treatment
    # is plotted exactly once.
    all_locs = []
    if separatedInputs:
        nrows = len(res.settings.conditions)
        for i in range(nrows):
            dev_locs = np.where((res.devices == device_id) & (res.treatments[:, i] > 0.0))[0]
            _, indices = np.unique(res.treatments[dev_locs, i], return_index=True)
            all_locs.append(dev_locs[indices])
        figsize = (2.2 * nplots, 1.6 * nrows + 1.2)
        sharex = True
    else:
        nrows = 1
        dev_locs = np.where(res.devices == device_id)[0]
        _, indices = np.unique(res.treatments[dev_locs, :], return_index=True, axis=0)
        all_locs.append(dev_locs[indices])
        figsize = (2.2 * nplots, 2.8)
        sharex = False

    # squeeze=False keeps `axs` a 2-D [row, column] array even for a single row or a
    # single signal; with the default squeezing a 1x1 grid yields a bare Axes that
    # cannot be indexed.
    f, axs = pp.subplots(nrows, nplots, sharex=sharex, sharey=True, figsize=figsize, squeeze=False)

    for i, locs in enumerate(all_locs):
        # One rainbow colour per selected sample.
        colors = [cm.rainbow(x) for x in np.linspace(0, 1, len(locs))]  # pylint: disable=no-member
        for idx in range(nplots):
            ax = axs[i, idx]

            w_mu = res.iw_predict_mu[locs, idx, :]
            w_std = res.iw_predict_std[locs, idx, :]
            # The property cycle has exactly len(locs) colours, so after the dots use one
            # full pass the mean lines wrap round to the same colours: the k-th sample's
            # band, dots and line all share colors[k].
            ax.set_prop_cycle("color", colors)
            for mu, std in zip(w_mu, w_std):
                ax.fill_between(res.times, mu - 2 * std, mu + 2 * std, alpha=0.1)
            ax.plot(res.times, res.X_obs[locs, idx, :].T, ".", alpha=1, markersize=2)
            ax.plot(res.times, w_mu.T, "-", lw=2, alpha=0.75)

            # Fixed axis ranges. NOTE(review): these assume time in hours up to ~17 h and
            # outputs normalised to roughly [0, 1] - confirm against the data pipeline.
            ax.set_xlim(0.0, 17)
            ax.set_xticks([0, 5, 10, 15])
            ax.set_ylim(-0.2, 1.2)

            # Row label (condition name) on the first column when rows are separated.
            if idx == 0 and nrows > 1:
                ax.set_ylabel(
                    res.settings.conditions[i] + " dilution", labelpad=25, fontweight="bold", fontsize=fs,
                )
            # Signal names as column titles on the top row only.
            if i == 0:
                ax.set_title(res.settings.signals[idx], fontsize=fs)

    # Global axis labels: overlay one frameless axes spanning the whole figure, hide its
    # ticks and tick labels, and label it instead of every subplot. Work on explicit
    # handles rather than pyplot's "current" figure/axes.
    big_ax = f.add_subplot(111, frameon=False)
    big_ax.tick_params(labelcolor="none", top=False, bottom=False, left=False, right=False)
    big_ax.set_xlabel("Time (h)", fontsize=fs, labelpad=7)
    big_ax.set_ylabel("Normalized output", fontsize=fs, labelpad=7)
    f.tight_layout()
    sns.despine(fig=f)

    return f