in interpolation.py [0:0]
def _rolling_mean(values, win):
    """Return the trailing rolling mean of a 1-D sequence as a NumPy array.

    With ``min_periods=1`` the first ``win - 1`` outputs average over the
    samples available so far instead of being NaN, so the output has the
    same length as the input.
    """
    return pd.Series(values).rolling(window=win, min_periods=1).mean().values


def plot_dtw_comparison(dpt_true, dpt_po, dpt_fa, dpt_umap,
                        idx_full, idx_po, idx_fa, idx_umap, data_full,
                        x_predicted_po, x_predicted_fa, x_predicted_umap,
                        col_names, fout,
                        n_plt=15, n2=3, win=5, pl_size=2):
    """Plot smoothed per-gene trends along pseudotime and score them with DTW.

    For each gene (column ``i`` of the data matrices, named ``col_names[i]``):

    1. Order the cells by the time returned from ``get_time_and_idx``.
    2. Smooth the values with a rolling mean of width ``win``.
    3. Compute the fastdtw distance (Euclidean point metric) between the true
       curve and each of the three predicted curves.

    The first ``n_plt`` genes are drawn on a grid with ``n2`` columns. The
    figure is saved to ``fout + '_compare_interpolation.pdf'``.

    Parameters
    ----------
    dpt_true, dpt_po, dpt_fa, dpt_umap
        Pseudotime inputs, passed to ``get_time_and_idx`` together with the
        matching index array. Presumably that helper returns the times plus
        an ordering index (e.g. an argsort); confirm against its definition.
    idx_full, idx_po, idx_fa, idx_umap
        Row indices into ``data_full`` / ``x_predicted_*`` respectively.
    data_full
        2-D array of true values; rows are selected by ``idx_full``.
    x_predicted_po, x_predicted_fa, x_predicted_umap
        2-D arrays of predicted values for the Poincaré, ForceAtlas2 and
        UMAP embeddings; rows are selected by the matching ``idx_*``.
    col_names
        Gene names. ``len(col_names)`` is the number of genes considered.
    fout
        Output path prefix.
    n_plt
        Number of genes to draw.
    n2
        Number of grid columns.
    win
        Rolling-mean window size.
    pl_size
        Per-panel size used to build ``figsize`` (inches).

    Returns
    -------
    tuple of np.ndarray
        ``(dtw_po, dtw_fa, dtw_umap)``: one DTW distance per scored gene.

    Notes
    -----
    Relies on module-level names not defined here: ``get_time_and_idx``,
    ``fastdtw``, ``euclidean``, ``plt``, ``pd``, ``np`` and the style
    globals ``fs`` (font size), ``cpal`` (colour list, >= 4 entries) and
    ``lw`` (line width).

    NOTE(review): DTW is computed only for genes that fall into a grid cell,
    i.e. the first ``min(len(col_names), n1 * n2)`` genes. If every gene
    should be scored, this needs a separate loop over ``range(len(col_names))``.
    """
    time_true, ix_true = get_time_and_idx(dpt_true, idx_full)
    time_po, ix_po = get_time_and_idx(dpt_po, idx_po)
    time_fa, ix_fa = get_time_and_idx(dpt_fa, idx_fa)
    time_umap, ix_umap = get_time_and_idx(dpt_umap, idx_umap)
    n_genes = len(col_names)

    # Enough rows of n2 columns to hold n_plt panels (ceiling division).
    n1, remainder = divmod(n_plt, n2)
    if remainder:
        n1 += 1
    # Force at least a 2x2 grid so plt.subplots returns a 2-D array of axes.
    n1 = max(n1, 2)
    n2 = max(n2, 2)
    fig, axs = plt.subplots(n1, n2, sharey=False,
                            figsize=(n2 * pl_size + 2, n1 * pl_size))

    dtw_po, dtw_fa, dtw_umap = [], [], []
    last_ax = None  # last panel that actually received curves

    # axs.flat walks the grid row by row, so panel i shows gene i.
    for i, ax in enumerate(axs.flat):
        ax.grid(False)
        ax.yaxis.set_tick_params(labelsize=fs)
        ax.xaxis.set_tick_params(labelsize=fs)
        if i >= n_genes:
            ax.axis('off')
            continue

        y_true = _rolling_mean(data_full[idx_full[ix_true], i], win)
        y_po = _rolling_mean(x_predicted_po[idx_po[ix_po], i], win)
        y_fa = _rolling_mean(x_predicted_fa[idx_fa[ix_fa], i], win)
        y_umap = _rolling_mean(x_predicted_umap[idx_umap[ix_umap], i], win)

        # fastdtw returns (distance, path); only the distance is kept.
        dtw_po.append(fastdtw(y_true, y_po, dist=euclidean)[0])
        dtw_fa.append(fastdtw(y_true, y_fa, dist=euclidean)[0])
        dtw_umap.append(fastdtw(y_true, y_umap, dist=euclidean)[0])

        if i < n_plt:
            ax.plot(time_true[ix_true], y_true, c=cpal[0], linewidth=lw * 2)
            ax.plot(time_po[ix_po], y_po, c=cpal[1], linewidth=lw)
            ax.plot(time_fa[ix_fa], y_fa, c=cpal[2], linewidth=lw)
            ax.plot(time_umap[ix_umap], y_umap, c=cpal[3], linewidth=lw)
            ax.set_title(col_names[i], fontsize=fs)
            last_ax = ax
        else:
            # Gene is scored but not drawn: hide the empty panel.
            ax.axis('off')

    dtw_po = np.array(dtw_po)
    dtw_fa = np.array(dtw_fa)
    dtw_umap = np.array(dtw_umap)

    # Attach the legend to a panel that has lines. legend(labels) pairs the
    # labels with that panel's lines in plotting order.
    if last_ax is not None:
        last_ax.legend(['True',
                        f'Poincaré: {np.median(dtw_po):.2f}',
                        f'ForceAtlas2: {np.median(dtw_fa):.2f}',
                        f'UMAP: {np.median(dtw_umap):.2f}'],
                       bbox_to_anchor=(1.4, 0.5), fontsize=fs)
        last_ax.set_xlabel('pseudotime', fontsize=fs)

    fig.tight_layout()
    fig.savefig(fout + '_compare_interpolation.pdf', format='pdf')
    return dtw_po, dtw_fa, dtw_umap