in vihds/plotting.py [0:0]
def combined_treatments(results, devices):
    """Compare model-data functional responses to inputs for multiple models.

    Builds a grid of panels with one row per entry of ``devices`` and
    ``2 * len(results)`` columns. The first ``len(results)`` columns are
    plotted against the C6 input and the remaining ones against the C12
    input, with one column per result in each group. In every panel the
    importance-weighted predictive mean (with error bars) and the observed
    values are drawn for each of the two plotted signals.

    Args:
        results: Non-empty sequence of result objects, one per model. Each
            must expose ``devices``, ``X_obs``, ``importance_weights``,
            ``PREDICT``, ``STD``, ``treatments`` and ``label``. The first
            result must also expose ``pretty_devices``, which supplies the
            row labels.
        devices: Non-empty sequence of device identifiers; each one is
            matched against ``res.devices`` to select the rows to plot.

    Returns:
        The matplotlib figure holding all panels.

    Raises:
        ValueError: If ``results`` or ``devices`` is empty.
    """
    ndev = len(devices)
    nres = len(results)
    if ndev == 0 or nres == 0:
        # Fail early with a clear message instead of a ZeroDivisionError from
        # the layout arithmetic below.
        raise ValueError("combined_treatments requires at least one result and one device")

    marker_size = 5
    font_size = 14
    obs_marker = "x"
    pred_marker = "o"

    # Panel geometry, passed to set_position as [left, bottom, width, height].
    width = 0.2
    lefts = [0.05, 0.57]  # left edge of the C6 group and of the C12 group
    bottom = 0.3 / ndev
    dx = 0.23  # horizontal step between neighbouring result columns
    dy = (1 - bottom) / ndev  # vertical step between device rows
    height = 0.9 * dy

    # Columns of ``res.treatments`` holding the C6 and C12 inputs.
    c6_idx = 1
    c12_idx = 0
    # Indices of the plotted signals along the last axis of X_obs / PREDICT /
    # STD, and the colour used for each.
    signal_ids = [2, 3]
    colors = ["y", "c"]

    # A single row leaves less vertical room, so the shared label is shortened.
    ytext = "Norm. fluorescence" if ndev == 1 else "Normalized fluorescence"

    # squeeze=False keeps ``ax`` two-dimensional even when there is one device,
    # so rows can always be taken as ax[iu].
    f, ax = pp.subplots(ndev, 2 * nres, sharex=True, squeeze=False, figsize=(9, 2.2 * ndev + 0.5))
    for iu, device_id in enumerate(devices):
        row = ax[iu]
        # NOTE(review): pretty_devices is indexed by position in ``devices``,
        # not by device_id - confirm the two are aligned in the caller.
        row[0].set_ylabel(results[0].pretty_devices[iu], labelpad=25, fontweight="bold", fontsize=font_size)
        for ir, res in enumerate(results):
            locs = np.where(res.devices == device_id)[0]
            # Move the signal axis to the front so obs[j] / predict[j] /
            # spread[j] select one signal.
            # presumably X_obs is (condition, time, signal), making -1 the
            # final time-point - TODO confirm.
            obs = np.transpose(res.X_obs[locs, -1, :], [1, 0])
            weights = res.importance_weights[locs]
            predict = np.transpose(res.PREDICT[locs, :], [2, 0, 1])
            spread = np.transpose(res.STD[locs, :], [2, 0, 1])
            # exp(t) - 1 of the stored treatments, for the selected rows only;
            # expm1 avoids cancellation error for small values.
            inputs = [
                np.expm1(res.treatments[locs, c6_idx]),
                np.expm1(res.treatments[locs, c12_idx]),
            ]
            for j, color in zip(signal_ids, colors):
                # Importance-weighted mean and mixture variance over axis 1.
                mu = np.sum(weights * predict[j], 1)
                var = np.sum(weights * (predict[j] ** 2 + spread[j] ** 2), 1) - mu ** 2
                # Rounding can push a near-zero variance slightly negative;
                # clamp so the error bars never become NaN.
                std = np.sqrt(np.maximum(var, 0.0))
                for k, conc in enumerate(inputs):
                    panel = row[ir + k * nres]
                    panel.errorbar(
                        conc, mu, yerr=std, fmt=pred_marker, mec="k", ms=marker_size, lw=1, color=color
                    )
                    # Observations are taken for the same signal j as the
                    # prediction they share a colour with.
                    panel.semilogx(conc, obs[j], obs_marker, ms=marker_size, lw=1, color=color)
            for k in range(2):
                panel = row[ir + k * nres]
                if ir > 0:
                    # Only the first column of each group keeps y tick labels.
                    panel.set_yticklabels([])
                panel.set_position([lefts[k] + ir * dx, bottom + (ndev - iu - 1) * dy, width, height])
                panel.set_xticks(np.logspace(0, 4, 3))
                panel.set_ylim(-0.1, 1.1)
                panel.set_yticks([0.0, 0.5, 1.0])
                panel.tick_params(axis="both", which="major", labelsize=font_size)
                if iu == 0:
                    panel.set_title(res.label, fontsize=font_size)

    # Global axis labels: add a big frameless axis over each column group and
    # hide its ticks so that only its axis labels are visible.
    xlabels = ["C$_6$ (nM)", "C$_{12}$ (nM)"]
    for k, xlabel in enumerate(xlabels):
        big = f.add_subplot(
            1,
            2,
            k + 1,
            frameon=False,
            position=[lefts[k], bottom, width + (nres - 1) * dx, height + (ndev - 1) * dy],
        )
        big.tick_params(labelcolor="none", top=False, bottom=False, left=False, right=False)
        big.set_xlabel(xlabel, fontsize=font_size, labelpad=10)
        big.set_ylabel(ytext, fontsize=font_size, labelpad=8)
    sns.despine(fig=f)
    return f