# plot_dtw_comparison() — excerpt from interpolation.py


def _rolling_mean(values, win):
    """Return the trailing rolling mean of a 1-D sequence as a float array.

    ``min_periods=1`` makes the first ``win - 1`` points average over however
    many samples are available instead of producing NaN.
    """
    series = pd.Series(np.asarray(values).ravel())
    return series.rolling(window=win, min_periods=1).mean().to_numpy()


def plot_dtw_comparison(dpt_true, dpt_po, dpt_fa, dpt_umap,
                        idx_full, idx_po, idx_fa, idx_umap, data_full,
                        x_predicted_po, x_predicted_fa, x_predicted_umap,
                        col_names, fout,
                        n_plt=15, n2=3, win=5, pl_size=2):
    """Plot smoothed true vs. predicted gene trends and score them with DTW.

    For each gene column that has a grid cell, the true expression and the
    Poincaré, ForceAtlas2 and UMAP predictions are reordered with the index
    returned by ``get_time_and_idx``. Each one is smoothed with a trailing
    rolling mean and compared to the smoothed true trend with FastDTW. The
    first ``n_plt`` genes are drawn on an ``n1 x n2`` grid. The figure is
    saved to ``fout + '_compare_interpolation.pdf'``.

    Relies on module-level names defined outside this function:
    ``get_time_and_idx``, ``fastdtw``, ``plt``, and the plotting constants
    ``fs`` (font size), ``lw`` (line width) and ``cpal`` (indexable colors).

    Parameters
    ----------
    dpt_true, dpt_po, dpt_fa, dpt_umap :
        Passed with the matching ``idx_*`` argument to ``get_time_and_idx``,
        which returns ``(time, ix)``. ``ix`` is applied both to ``time`` and
        to the ``idx_*`` rows (presumably a pseudotime sort order; confirm).
    idx_full, idx_po, idx_fa, idx_umap :
        Row indices into ``data_full`` / ``x_predicted_po`` /
        ``x_predicted_fa`` / ``x_predicted_umap`` respectively.
    data_full, x_predicted_po, x_predicted_fa, x_predicted_umap : 2-D arrays
        Expression matrices indexed ``[rows, gene_column]``.
    col_names : sequence of str
        Gene names. ``col_names[i]`` titles the panel for column ``i``.
    fout : str
        Output path prefix.
    n_plt : int
        Number of genes to draw.
    n2 : int
        Number of grid columns. The row count is derived from ``n_plt``.
    win : int
        Rolling-mean window length.
    pl_size : float
        Per-panel size in inches.

    Returns
    -------
    (dtw_po, dtw_fa, dtw_umap) : tuple of np.ndarray
        Per-gene DTW distances to the true trend. They are computed for the
        first ``min(len(col_names), n1 * n2)`` gene columns, i.e. every gene
        that has a grid cell, not only the ``n_plt`` drawn ones.
    """
    time_true, ix_true = get_time_and_idx(dpt_true, idx_full)
    time_po, ix_po = get_time_and_idx(dpt_po, idx_po)
    time_fa, ix_fa = get_time_and_idx(dpt_fa, idx_fa)
    time_umap, ix_umap = get_time_and_idx(dpt_umap, idx_umap)

    n_genes = len(col_names)

    # Ceiling division: enough rows to hold n_plt panels.
    n1 = -(-n_plt // n2)
    # plt.subplots squeezes a single row/column into a 1-D array; keep at
    # least two of each so `axs` is always 2-D.
    if n1 == 1:
        n1 = 2
    if n2 == 1:
        n2 = 2

    fig, axs = plt.subplots(n1, n2, sharey=False,
                            figsize=(n2 * pl_size + 2, n1 * pl_size))

    dtw_po, dtw_fa, dtw_umap = [], [], []
    last_plotted = None

    # axs.flat walks the grid row by row, so cell i is gene column i.
    for i, ax in enumerate(axs.flat):
        # grid(False): a string like 'off' is truthy and would turn grid ON.
        ax.grid(False)
        ax.yaxis.set_tick_params(labelsize=fs)
        ax.xaxis.set_tick_params(labelsize=fs)

        if i >= n_genes:
            # No gene left for this cell; hide it instead of an empty frame.
            ax.axis('off')
            continue

        y_true = _rolling_mean(data_full[idx_full[ix_true], i], win)
        y_po = _rolling_mean(x_predicted_po[idx_po[ix_po], i], win)
        y_fa = _rolling_mean(x_predicted_fa[idx_fa[ix_fa], i], win)
        y_umap = _rolling_mean(x_predicted_umap[idx_umap[ix_umap], i], win)

        # fastdtw's default point distance for 1-D series is |a - b|, which
        # is what scipy's `euclidean` computes on scalars. Unlike `euclidean`,
        # it does not require each point to be a 1-D vector.
        dtw_po.append(fastdtw(y_true, y_po)[0])
        dtw_fa.append(fastdtw(y_true, y_fa)[0])
        dtw_umap.append(fastdtw(y_true, y_umap)[0])

        if i < n_plt:
            ax.plot(time_true[ix_true], y_true, c=cpal[0], linewidth=lw * 2)
            ax.plot(time_po[ix_po], y_po, c=cpal[1], linewidth=lw)
            ax.plot(time_fa[ix_fa], y_fa, c=cpal[2], linewidth=lw)
            ax.plot(time_umap[ix_umap], y_umap, c=cpal[3], linewidth=lw)
            ax.set_title(col_names[i], fontsize=fs)
            last_plotted = ax
        else:
            # Scored but not drawn.
            ax.axis('off')

    dtw_po = np.array(dtw_po)
    dtw_fa = np.array(dtw_fa)
    dtw_umap = np.array(dtw_umap)

    # Attach legend/xlabel to the last panel that actually has lines; the
    # bottom-right cell may be blank.
    label_ax = last_plotted if last_plotted is not None else axs[-1, -1]
    label_ax.legend(['True',
                     f'Poincaré: {np.median(dtw_po):.2f}',
                     f'ForceAtlas2: {np.median(dtw_fa):.2f}',
                     f'UMAP: {np.median(dtw_umap):.2f}'],
                    bbox_to_anchor=(1.4, 0.5), fontsize=fs)
    label_ax.set_xlabel('pseudotime', fontsize=fs)

    fig.tight_layout()
    fig.savefig(fout + '_compare_interpolation.pdf', format='pdf')

    return dtw_po, dtw_fa, dtw_umap