in viz_dataset.py [0:0]
def _diagonal_kde_on_grid(samples, weights, precision_diag, xs, ys, chunk_size=65536):
    """Evaluate a 2-D Gaussian KDE with an axis-aligned kernel on a rectilinear grid.

    The density at ``(xs[i], ys[j])`` is::

        sum_k w_k * sqrt(px * py) / (2 * pi)
              * exp(-px * (xs[i] - samples[0, k]) ** 2 / 2)
              * exp(-py * (ys[j] - samples[1, k]) ** 2 / 2)

    with ``(px, py) = precision_diag``. Because the kernel factorises per axis,
    the whole grid is a sum of outer products, computed as ``gx.T @ gy`` over
    chunks of samples so memory stays O(chunk_size * (len(xs) + len(ys))).

    Args:
        samples: array of shape (2, n); row 0 holds x values, row 1 holds y values.
        weights: array of shape (n,) with the per-sample weights.
        precision_diag: the two per-axis kernel precisions (inverse variances).
        xs: 1-D grid coordinates along x.
        ys: 1-D grid coordinates along y.
        chunk_size: number of samples folded into each matrix product.

    Returns:
        Array of shape (len(xs), len(ys)) with the density at every grid node.
    """
    samples = np.asarray(samples, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    xs = np.asarray(xs, dtype=np.float64)
    ys = np.asarray(ys, dtype=np.float64)
    px, py = np.asarray(precision_diag, dtype=np.float64)

    density = np.zeros((xs.size, ys.size))
    for start in range(0, samples.shape[1], chunk_size):
        stop = start + chunk_size
        # Per-axis kernel factors, shapes (chunk, len(xs)) and (chunk, len(ys)).
        gx = np.exp(-0.5 * px * (xs[None, :] - samples[0, start:stop, None]) ** 2)
        gy = np.exp(-0.5 * py * (ys[None, :] - samples[1, start:stop, None]) ** 2)
        density += (gx * weights[start:stop, None]).T @ gy
    return density * (np.sqrt(px * py) / (2.0 * np.pi))


def plot_kde(coords, S_std, S_mean, savepath, dataset_name, text=None, name=None):
    """Plot a kernel density estimate of ``coords`` and save it as a PNG.

    ``coords`` is de-standardised as ``coords * S_std + S_mean``. The statistics
    are cast with ``.to(coords)``. Column 0 is used as the x axis (longitude)
    and column 1 as the y axis (latitude). The density uses the bandwidth of
    SciPy's ``gaussian_kde`` with default settings, keeping only the diagonal of
    its precision matrix, which gives an axis-aligned kernel. It is evaluated on
    a 100x100 grid spanning ``BBOXES[dataset_name]`` and drawn as filled
    contours. If ``MAPS`` has a truthy entry for ``dataset_name``, that image is
    drawn underneath with the same extent.

    Args:
        coords: tensor of standardised coordinates, indexed as ``coords[:, 0]``
            and ``coords[:, 1]``. Presumably shape (n, 2); TODO confirm.
        S_std: scale used to undo the standardisation.
        S_mean: offset used to undo the standardisation.
        savepath: root output directory. The file is written to
            ``savepath/<dataset_name>/<name>.png``.
        dataset_name: key into ``BBOXES`` (required) and ``MAPS`` (optional).
        text: optional label drawn near the top-left of the axes.
        name: file stem; defaults to ``"<dataset_name>_density"``.
    """
    name = f"{dataset_name}_density" if name is None else name
    coords = coords * S_std.to(coords) + S_mean.to(coords)
    longs = coords[:, 0].detach().cpu().numpy()
    lats = coords[:, 1].detach().cpu().numpy()
    bbox = BBOXES[dataset_name]
    x_min, x_max, y_min, y_max = bbox[0], bbox[1], bbox[2], bbox[3]

    # Compute the density before opening a figure, so a failure here
    # (e.g. singular covariance) cannot leak an open matplotlib figure.
    samples = np.stack([longs, lats], axis=0)
    kernel = gaussian_kde(samples)
    # Axis-aligned bandwidth: keep only the diagonal of the precision matrix.
    # Assigning ``kernel.inv_cov`` raises AttributeError on SciPy >= 1.10, where
    # it is a read-only property and ``evaluate`` no longer reads it. So the
    # diagonal-kernel density is evaluated explicitly instead.
    precision_diag = np.diag(kernel.inv_cov)
    xs = np.linspace(x_min, x_max, 100)
    ys = np.linspace(y_min, y_max, 100)
    X, Y = np.meshgrid(xs, ys, indexing="ij")  # same layout as np.mgrid
    n_samples = samples.shape[1]
    Z = _diagonal_kde_on_grid(
        samples, np.full(n_samples, 1.0 / n_samples), precision_diag, xs, ys
    )

    map_path = MAPS.get(dataset_name)
    map_img = plt.imread(map_path) if map_path else None
    # Match the figure's aspect ratio to the background map when there is one.
    aspect = map_img.shape[0] / map_img.shape[1] if map_img is not None else 1.0
    fig, ax = plt.subplots(figsize=(FIGSIZE, FIGSIZE * aspect))
    try:
        if map_img is not None:
            ax.imshow(map_img, zorder=0, extent=bbox)
        ax.contourf(X, Y, Z, levels=10, alpha=0.6, cmap='RdGy')
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        if text is not None:
            txt = ax.text(0.15, 0.9, text,
                          horizontalalignment="center",
                          verticalalignment="center",
                          transform=ax.transAxes,
                          size=16,
                          color='white')
            # Black outline keeps the white label readable over any background.
            txt.set_path_effects([PathEffects.withStroke(linewidth=5, foreground='black')])
        ax.set_axis_off()
        out_dir = os.path.join(savepath, f"{dataset_name}")
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, f"{name}.png"), bbox_inches='tight', dpi=DPI)
    finally:
        plt.close(fig)