def plot_kde()

in viz_dataset.py [0:0]


def _diagonal_kde_on_grid(xs, ys, grid_x, grid_y, precision, chunk_size=4096):
    """Evaluate an axis-aligned Gaussian KDE on the grid ``grid_x`` x ``grid_y``.

    Only the diagonal of ``precision`` (the inverse kernel covariance) is used,
    giving per-axis kernel variances ``1 / precision[i, i]``; every sample gets
    the same weight ``1 / len(xs)``. Because such a kernel factorizes over the
    two axes, the grid density equals ``Kx @ Ky.T`` for the per-axis kernel
    matrices; it is accumulated over chunks of samples to bound peak memory.

    Args:
        xs, ys: 1-D arrays of sample coordinates (same length, non-empty).
        grid_x, grid_y: 1-D arrays of grid coordinates along each axis.
        precision: 2x2 inverse kernel covariance; off-diagonal terms are ignored.
        chunk_size: number of samples processed per matrix product.

    Returns:
        Array of shape ``(len(grid_x), len(grid_y))`` whose ``[j, k]`` entry is
        the density at ``(grid_x[j], grid_y[k])``.
    """
    sigma_x = 1.0 / np.sqrt(precision[0, 0])
    sigma_y = 1.0 / np.sqrt(precision[1, 1])
    density = np.zeros((grid_x.size, grid_y.size))
    for start in range(0, xs.size, chunk_size):
        stop = start + chunk_size
        kx = np.exp(-0.5 * ((grid_x[:, None] - xs[None, start:stop]) / sigma_x) ** 2)
        ky = np.exp(-0.5 * ((grid_y[:, None] - ys[None, start:stop]) / sigma_y) ** 2)
        density += kx @ ky.T
    # Per-axis Gaussian normalisation and the uniform 1/N sample weight.
    density /= xs.size * 2.0 * np.pi * sigma_x * sigma_y
    return density


def plot_kde(coords, S_std, S_mean, savepath, dataset_name, text=None, name=None):
    """Plot a filled-contour KDE of coordinates and save it as a PNG.

    The coordinates are un-standardised (``coords * S_std + S_mean``), a
    Gaussian KDE with an axis-aligned kernel is evaluated on a 100x100 grid
    spanning ``BBOXES[dataset_name]``, and the contours are drawn, over the
    map image ``MAPS[dataset_name]`` when that entry is truthy.

    Args:
        coords: tensor with ``.detach().cpu().numpy()``; column 0 is plotted on
            the x-axis (longitude) and column 1 on the y-axis (latitude).
            Presumably standardised coordinates of shape (N, >=2) -- confirm
            against callers.
        S_std, S_mean: tensors used to undo the standardisation; converted to
            ``coords``' dtype/device via ``.to(coords)``.
        savepath: root output directory.
        dataset_name: key into ``MAPS``/``BBOXES``; also the output subfolder.
        text: optional label drawn in the upper-left area of the axes.
        name: output file stem; defaults to ``f"{dataset_name}_density"``.

    Side effects:
        Creates ``savepath/<dataset_name>/`` if needed and writes
        ``<name>.png`` there.
    """
    name = f"{dataset_name}_density" if name is None else name
    n_grid = 100  # grid resolution per axis

    coords = coords * S_std.to(coords) + S_mean.to(coords)
    longs = coords[:, 0].detach().cpu().numpy()
    lats = coords[:, 1].detach().cpu().numpy()

    bbox = BBOXES[dataset_name]
    x_min, x_max, y_min, y_max = bbox[0], bbox[1], bbox[2], bbox[3]

    if MAPS[dataset_name]:
        map_img = plt.imread(MAPS[dataset_name])
        fig, ax = plt.subplots(figsize=(FIGSIZE, FIGSIZE * map_img.shape[0] / map_img.shape[1]))
        ax.imshow(map_img, zorder=0, extent=bbox)
    else:
        fig, ax = plt.subplots(figsize=(FIGSIZE, FIGSIZE))

    # Bandwidth comes from gaussian_kde's default rule; the kernel is then made
    # axis-aligned by keeping only the diagonal of its precision matrix.
    # Overwriting `kernel.inv_cov` (the old approach) raises AttributeError on
    # SciPy >= 1.10, where it is a read-only property no longer used by
    # evaluation, so the diagonal-kernel density is computed explicitly.
    kernel = gaussian_kde(np.stack([longs, lats], axis=0))
    precision = np.linalg.inv(kernel.covariance)
    grid_x = np.linspace(x_min, x_max, n_grid)
    grid_y = np.linspace(y_min, y_max, n_grid)
    X, Y = np.meshgrid(grid_x, grid_y, indexing="ij")
    Z = _diagonal_kde_on_grid(longs, lats, grid_x, grid_y, precision)
    ax.contourf(X, Y, Z, levels=10, alpha=0.6, cmap='RdGy')
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)

    if text is not None:
        txt = ax.text(0.15, 0.9, text,
                      horizontalalignment="center",
                      verticalalignment="center",
                      transform=ax.transAxes,
                      size=16,
                      color='white')
        # Black outline keeps the white label readable over any background.
        txt.set_path_effects([PathEffects.withStroke(linewidth=5, foreground='black')])

    # Operate on this figure explicitly rather than pyplot's "current" one.
    ax.set_axis_off()
    out_dir = os.path.join(savepath, str(dataset_name))
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"{name}.png"), bbox_inches='tight', dpi=DPI)
    plt.close(fig)