def plot_density()

in viz_dataset.py [0:0]


def _to_numpy(x):
    """Return ``x`` as a NumPy array, detaching and moving torch tensors to CPU first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def plot_density(loglik_fn, spatial_locations, index, S_mean, S_std, savepath, dataset_name, device, text=None, fp64=False):
    """Plot the density given by ``loglik_fn`` over the dataset's bounding box and save it.

    An N x N grid (N = 50) is laid over ``BBOXES[dataset_name]``. The grid is
    standardized with ``S_mean`` / ``S_std`` and passed to ``loglik_fn``. The
    exponentiated result is drawn as filled contours, optionally on top of the
    dataset's map image. ``spatial_locations`` are mapped back to the original
    coordinates and overlaid as black "x" markers.

    Two files are written under ``<savepath>/figs/``:
      * ``data{index}.npz`` holds the arrays ``X``, ``Y``, ``Z`` and ``spatial_locations``.
      * ``density{index}.png`` holds the rendered figure.

    Args:
        loglik_fn: Callable taking an ``(N*N, 2)`` tensor of standardized
            coordinates. It must return log-densities (a torch tensor) with
            ``N*N`` elements in total.
        spatial_locations: Observed locations in standardized coordinates,
            indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y). May be a NumPy array
            or a torch tensor.
        index: Suffix used in the output file names.
        S_mean: Per-coordinate mean used for standardization (torch tensor).
        S_std: Per-coordinate std used for standardization (torch tensor).
        savepath: Directory under which ``figs/`` is created.
        dataset_name: Key into the module-level ``BBOXES`` and ``MAPS`` tables.
        device: Torch device on which the grid is evaluated.
        text: Optional label drawn in the upper-left part of the axes.
        fp64: Evaluate in float64 if True, otherwise float32.
    """
    N = 50  # grid points per axis; the density is evaluated at N * N locations
    # (x_min, x_max, y_min, y_max); the same ordering is used for the axis limits below.
    bbox = BBOXES[dataset_name]

    # Evaluation grid in the original (un-standardized) coordinate frame.
    X, Y = np.meshgrid(np.linspace(bbox[0], bbox[1], N), np.linspace(bbox[2], bbox[3], N))
    S = torch.tensor(np.stack([X.reshape(-1), Y.reshape(-1)], axis=1)).to(device)
    S = S.double() if fp64 else S.float()
    # loglik_fn operates on standardized coordinates.
    S = (S - S_mean.to(S)) / S_std.to(S)
    # Not wrapped in torch.no_grad(): loglik_fn may rely on autograd internally (unverified).
    logp = loglik_fn(S)
    Z = logp.exp().detach().cpu().numpy().reshape(N, N)

    # Un-standardize the observed points. _to_numpy handles CUDA / requires-grad
    # tensors, where a plain np.array(...) would raise.
    locations = _to_numpy(spatial_locations) * _to_numpy(S_std) + _to_numpy(S_mean)

    map_path = MAPS[dataset_name]
    map_img = plt.imread(map_path) if map_path else None
    # Match the figure's aspect ratio to the background image, if any.
    aspect = map_img.shape[0] / map_img.shape[1] if map_img is not None else 1
    fig, ax = plt.subplots(figsize=(FIGSIZE, FIGSIZE * aspect))
    try:
        if map_img is not None:
            ax.imshow(map_img, zorder=0, extent=bbox)

        ax.contourf(X, Y, Z, levels=20, alpha=0.7, cmap='RdGy')
        ax.scatter(locations[:, 0], locations[:, 1], s=20**2, alpha=1.0, marker="x", color="k")

        ax.set_xlim(bbox[0], bbox[1])
        ax.set_ylim(bbox[2], bbox[3])

        if text:
            txt = ax.text(0.15, 0.9, text,
                          horizontalalignment="center",
                          verticalalignment="center",
                          transform=ax.transAxes,
                          size=16,
                          color='white')
            # Black outline keeps the white label readable over any background.
            txt.set_path_effects([PathEffects.withStroke(linewidth=5, foreground='black')])

        ax.axis('off')
        fig_dir = os.path.join(savepath, "figs")
        os.makedirs(fig_dir, exist_ok=True)
        np.savez(os.path.join(fig_dir, f"data{index}.npz"),
                 **{"X": X, "Y": Y, "Z": Z, "spatial_locations": locations})
        fig.savefig(os.path.join(fig_dir, f"density{index}.png"), bbox_inches='tight', dpi=DPI)
    finally:
        # Always release this specific figure, even if plotting or saving fails.
        plt.close(fig)