def xval_individual_2treatments()

in vihds/plotting.py [0:0]


def xval_individual_2treatments(res, device_id):
    """Multi-panel plot for each sample, with treatments separated into 2 groups.

    The figure has two columns of panel groups. Column ``col`` (0 or 1) contains one row
    per sample of ``device_id`` whose treatment ``col`` is strictly positive. Rows are
    ordered top-to-bottom by increasing value of that treatment. Each row holds one panel
    per signal showing:

    - the observations (black dots);
    - the predicted mean (line);
    - a band of +/- 2 predicted standard deviations.

    All three are divided by the maximum observed value of that signal over all samples
    and time points.

    Args:
        res: results object. Attributes read here: ``X_obs``, ``iw_predict_mu`` and
            ``iw_predict_std`` (indexed as ``[sample, signal, time]``), ``times``,
            ``devices``, ``treatments`` (indexed as ``[sample, treatment]``),
            ``settings.conditions`` and ``settings.signals``.
        device_id: the device whose samples are plotted; compared element-wise against
            ``res.devices``.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: if no sample of ``device_id`` has a positive value for either of the
            first two treatments (there would be nothing to plot).
    """
    nplots = res.X_obs.shape[1]
    colors = ["tab:gray", "r", "y", "c"]
    # Per-signal normaliser: the largest observation over all samples and time points.
    maxs = np.max(res.X_obs, axis=(0, 2))

    fs = 14
    # For each of the first two treatments: the samples of this device where that treatment
    # is positive, sorted by increasing treatment value.
    both_locs = []
    for col in range(2):
        all_locs = np.where((res.devices == device_id) & (res.treatments[:, col] > 0.0))[0]
        indices = np.argsort(res.treatments[all_locs, col])
        both_locs.append(all_locs[indices])

    ntreatments = max(map(len, both_locs))
    if ntreatments == 0:
        # Without this guard the figure would have zero height and the layout maths below
        # would divide by zero.
        raise ValueError(
            "No samples with a positive value for either treatment found for device {!r}".format(device_id)
        )

    f = pp.figure(figsize=(12, 1.5 * ntreatments))
    # Panel layout in figure-fraction units; identical for both columns.
    bottom = 0.4 / ntreatments
    width = 0.33 / nplots
    dx = 0.38 / nplots
    dy = (1 - bottom) / ntreatments
    height = 0.8 * dy
    for col, locs in enumerate(both_locs):
        left = 0.1 + col * 0.5
        for i, loc in enumerate(locs):
            # TODO(ndalchau): Incorporate units into conditions specification (here we assume nM)
            treatment_str = gen_treatment_str(res.settings.conditions, res.treatments[loc], unit="nM")
            # Grid row counted from the bottom of the figure; i == 0 is the top row.
            row = ntreatments - i - 1

            for idx, maxi in enumerate(maxs):
                ax = f.add_subplot(ntreatments, 2 * nplots, col * nplots + row * 2 * nplots + idx + 1)
                # The explicit position overrides the subplot grid slot chosen above.
                ax.set_position([left + idx * dx, bottom + row * dy, width, height])
                # Cycle the palette so that more than len(colors) signals do not raise IndexError.
                color = colors[idx % len(colors)]

                mu = res.iw_predict_mu[loc, idx, :]
                std = res.iw_predict_std[loc, idx, :]

                ax.fill_between(
                    res.times, (mu - 2 * std) / maxi, (mu + 2 * std) / maxi, alpha=0.25, color=color,
                )
                ax.plot(res.times, res.X_obs[loc, idx, :] / maxi, "k.", markersize=2)
                ax.plot(res.times, mu / maxi, "-", lw=2, alpha=0.75, color=color)
                ax.set_xlim(0.0, 17)
                ax.set_xticks([0, 5, 10, 15])
                ax.set_ylim(-0.2, 1.2)
                ax.tick_params(axis="both", which="major", labelsize=fs)

                if i == 0:
                    ax.set_title(res.settings.signals[idx], fontsize=fs)
                # Keep time tick labels only on the lowest panel of *this* column. The two
                # columns may have different numbers of samples, so compare against len(locs)
                # rather than the overall row count.
                if i < len(locs) - 1:
                    ax.set_xticklabels([])
                if idx == 0:
                    ax.set_ylabel(treatment_str, labelpad=25, fontsize=fs - 2)
                else:
                    ax.set_yticklabels([])

        # Add labels
        f.text(
            left - 0.35 * dx, 0.5, "Normalized output", ha="center", va="center", rotation=90, fontsize=fs,
        )
        f.text(left + 2 * dx, 0, "Time (h)", ha="center", va="bottom", fontsize=fs)

    # Remove the top/right spines of every panel once, after all axes exist. This gives the
    # same result as despining after each panel, without re-processing all axes each time.
    sns.despine(fig=f)
    return f