def xval_treatments()

in vihds/plotting.py [0:0]


def xval_treatments(res, devices):
    """Compare the final simulated points against the equivalent data-points to establish functional response.

    Builds a grid with one row per entry of ``devices`` and one column per signal in
    ``res.settings.signals``. In every panel, for each condition in ``res.settings.conditions``,
    the model prediction (``res.iw_predict_mu`` with ``res.iw_predict_std`` as error bars) and the
    observation (``res.X_obs``) at the last index of their third axis are plotted against that
    condition's input value on a log-scaled x-axis.

    Args:
        res: cross-validation results. Reads ``settings.signals``, ``settings.conditions``,
            ``settings.devices``, ``devices``, ``treatments``, ``iw_predict_mu``,
            ``iw_predict_std`` and ``X_obs``.
        devices: sized iterable of device identifiers; each is compared element-wise against
            ``res.devices`` to select the samples shown in that row.

    Returns:
        The matplotlib ``Figure`` containing the panels.
    """
    signals = res.settings.signals
    conditions = res.settings.conditions
    nplots = len(signals)
    ndev = len(devices)

    ms = 5
    fs = 14
    obs_mk = "x"
    pred_mk = "o"
    colors = ["g", "r", "b"]
    edges = ["darkgreen", "darkred", "darkblue"]

    # Input values per sample and condition, computed once rather than per device.
    # expm1(t) == exp(t) - 1 but is more accurate for small t; this presumably undoes a
    # log(1 + x) transform applied to the treatments upstream (TODO confirm).
    inputs = np.expm1(res.treatments[:, : len(conditions)])

    # squeeze=False keeps `axs` 2-D even for a single device and/or a single signal, so
    # `axs[row, col]` is always valid (with the default squeeze=True a 1xN, Nx1 or 1x1 grid
    # comes back as a 1-D array or a bare Axes).
    f, axs = pp.subplots(ndev, nplots, sharex=True, sharey=True, figsize=(9, 2.2 * ndev), squeeze=False)

    # The legend goes on the top-right panel. Its artists are kept so labels are attached
    # explicitly instead of relying on the order in which matplotlib discovers handles.
    legend_ax = axs[0, nplots - 1]
    data_handles = []
    model_handles = []

    for iu, device_id in enumerate(devices):
        locs = np.where(res.devices == device_id)[0]
        for j, signal in enumerate(signals):
            ax = axs[iu, j]
            mu = res.iw_predict_mu[locs, j, -1]
            std = res.iw_predict_std[locs, j, -1]
            obs = res.X_obs[locs, j, -1]
            for ci in range(len(conditions)):
                cvalues = inputs[locs, ci]
                # Cycle the palettes so more than three conditions do not raise IndexError.
                color = colors[ci % len(colors)]
                edge = edges[ci % len(edges)]
                model = ax.errorbar(
                    cvalues, mu, yerr=std, fmt=pred_mk, ms=ms, lw=1, mec=edge, color=color, zorder=ci,
                )
                # Marker-only format string: the colour comes solely from `color`, avoiding the
                # "color is redundantly defined" warning a colour-bearing fmt would trigger.
                (data,) = ax.semilogx(cvalues, obs, obs_mk, ms=ms, lw=1, color=edge, zorder=ci + 20)
                if ax is legend_ax:
                    model_handles.append(model)
                    data_handles.append(data)
            ax.set_ylim(-0.1, 1.1)
            ax.tick_params(axis="both", which="major", labelsize=fs)
            ax.set_xticks(np.logspace(0, 4, 3))
            if j == 0:
                # NOTE(review): the label is picked by position in `devices` (iu), not by `device_id`;
                # the two agree only when `devices` enumerates settings.devices in order -- confirm.
                ax.set_ylabel(res.settings.devices[iu], labelpad=25, fontweight="bold", fontsize=fs)
            if iu == 0:
                ax.set_title(signal, fontsize=fs)

    # Data entries first, then model entries, one per condition each.
    legend_ax.legend(
        handles=data_handles + model_handles,
        labels=[f"{c} (data)" for c in conditions] + [f"{c} (model)" for c in conditions],
    )

    # A single row of panels gets the shorter y-axis text.
    ytext = "Normalized fluorescence" if ndev > 1 else "Norm. fluorescence"

    # Global axis labels: add a frameless axes spanning the figure and hide its ticks.
    big_ax = f.add_subplot(111, frameon=False)
    big_ax.tick_params(labelcolor="none", top=False, bottom=False, left=False, right=False)
    big_ax.set_xlabel(" / ".join(conditions), fontsize=fs, labelpad=7)
    big_ax.set_ylabel(ytext, fontsize=fs, labelpad=7)
    sns.despine(fig=f)

    return f