# combined_treatments()
#
# in vihds/plotting.py [0:0]


def combined_treatments(results, devices):
    """Compare model-data functional responses to inputs for multiple models.

    Builds a grid of axes with one row per entry of ``devices`` and two groups
    of columns (responses plotted against C6, then against C12), each group
    holding one column per entry of ``results``. For each of the two plotted
    signals, the importance-weighted prediction is drawn as circles with error
    bars and the observation as crosses in the same colour.

    Args:
        results: Sequence of result objects, one per model. Each must expose
            ``devices``, ``X_obs``, ``importance_weights``, ``PREDICT``,
            ``STD``, ``treatments`` and ``label``; the first one must also
            expose ``pretty_devices`` (used for the row labels).
        devices: Sequence of device ids; one figure row is drawn per id.

    Returns:
        The matplotlib figure containing all panels.
    """
    ndev = len(devices)
    nres = len(results)

    # Marker / font styling.
    ms = 5
    fs = 14
    obs_mk = "x"
    pred_mk = "o"

    # Manual panel layout, as passed to ``Axes.set_position``.
    width = 0.2
    lefts = [0.05, 0.57]  # left edge of the C6 group and of the C12 group
    bottom = 0.3 / ndev
    dx = 0.23
    dy = (1 - bottom) / ndev
    height = 0.9 * dy

    # Columns of ``res.treatments`` holding each input, and the signal
    # indices (with their colours) that are plotted.
    c6_idx = 1
    c12_idx = 0
    ids = [2, 3]
    colors = ["y", "c"]

    # The shared y-label depends only on the number of rows, so pick it once
    # (a shorter text is used when there is a single, short row).
    ytext = "Norm. fluorescence" if ndev == 1 else "Normalized fluorescence"

    # squeeze=False keeps ``ax`` two-dimensional even for a single device, so
    # rows can always be addressed as ``ax[iu]``.
    f, ax = pp.subplots(ndev, 2 * nres, sharex=True, figsize=(9, 2.2 * ndev + 0.5), squeeze=False)
    for iu, device_id in enumerate(devices):
        row = ax[iu]
        row[0].set_ylabel(results[0].pretty_devices[iu], labelpad=25, fontweight="bold", fontsize=fs)
        for ir, res in enumerate(results):
            # Samples belonging to this device.
            locs = np.where(res.devices == device_id)[0]
            # Re-order so that the leading axis is the signal index.
            OBS = np.transpose(res.X_obs[locs, -1, :], [1, 0])
            IW = res.importance_weights[locs]
            PREDICT = np.transpose(res.PREDICT[locs, :], [2, 0, 1])
            STD = np.transpose(res.STD[locs, :], [2, 0, 1])
            # NOTE(review): treatments look like they are stored as
            # log(1 + concentration); exp(.) - 1 undoes that - confirm.
            C6 = np.exp(res.treatments[locs, c6_idx]) - 1
            C12 = np.exp(res.treatments[locs, c12_idx]) - 1

            for j, color in zip(ids, colors):
                # Importance-weighted mean and spread of the prediction for
                # signal j (assumes the weights are normalised over axis 1).
                mu = np.sum(IW * PREDICT[j], 1)
                var = np.sum(IW * (PREDICT[j] ** 2 + STD[j] ** 2), 1) - mu ** 2
                # Round-off can push a tiny variance below zero; clamp so the
                # error bars do not become NaN.
                std = np.sqrt(np.maximum(var, 0.0))
                for k, C in enumerate([C6, C12]):
                    ic = ir + k * nres
                    row[ic].errorbar(C, mu, yerr=std, fmt=pred_mk, mec="k", ms=ms, lw=1, color=color)
                    # Observations must come from the same signal j as the
                    # prediction they are compared against.
                    row[ic].semilogx(C, OBS[j], obs_mk, ms=ms, lw=1, color=color)

            if ir > 0:
                # Only the first column of each group keeps its y tick labels.
                row[ir].set_yticklabels([])
                row[ir + nres].set_yticklabels([])
            for k in range(2):
                ic = ir + k * nres
                row[ic].set_position([lefts[k] + ir * dx, bottom + (ndev - iu - 1) * dy, width, height])
                row[ic].set_xticks(np.logspace(0, 4, 3))
                row[ic].set_ylim(-0.1, 1.1)
                row[ic].set_yticks([0.0, 0.5, 1.0])
                row[ic].tick_params(axis="both", which="major", labelsize=fs)
                if iu == 0:
                    row[ic].set_title(res.label, fontsize=fs)

    # Global axis labels: add a big axis, then hide frame
    xlabels = ["C$_6$ (nM)", "C$_{12}$ (nM)"]
    for k, xlabel in enumerate(xlabels):
        f.add_subplot(
            1, 2, k + 1, frameon=False, position=[lefts[k], bottom, width + (nres - 1) * dx, height + (ndev - 1) * dy],
        )
        pp.tick_params(labelcolor="none", top=False, bottom=False, left=False, right=False)
        pp.xlabel(xlabel, fontsize=fs, labelpad=10)
        pp.ylabel(ytext, fontsize=fs, labelpad=8)

    sns.despine()

    return f