in DataLoaders/RandomWalks.py [0:0]
def vis_walk(walk):
    '''Render a 2-D walk as a scatter-plot image.

    Each point of the walk is drawn as a dot. Dots are coloured by sampling
    ``plot_util.colormap`` at evenly spaced rows, so the colour progression
    along the walk encodes time.

    Args:
        walk: (nT+1) X 2 array-like; column 0 holds x, column 1 holds y.
            The axes are fixed to [-2, 2] in both directions, so points
            outside that square fall outside the visible plot area.
    Returns:
        im: 200 X 200 X 4 numpy array, the RGBA image produced by
            ``plt.imread`` on the rendered PNG.
    '''
    import io  # local import: the render is kept in memory, no temp file

    walk = np.asarray(walk)
    t = walk.shape[0]
    xs = walk[:, 0]
    ys = walk[:, 1]
    # Spread t colour indices over rows 0..255 of the colormap.
    # Builtin int, not np.int (np.int was removed in NumPy 1.24).
    color_inds = np.linspace(0, 255, t).astype(int).tolist()
    # NOTE(review): assumes plot_util.colormap is a 2-D array with at least
    # 256 rows of colour values -- confirm against plot_util.
    cs = plot_util.colormap[color_inds, :]

    # 4 in x 4 in at 50 dpi -> 200 x 200 pixel canvas.
    fig = plt.figure(figsize=(4, 4), dpi=50)
    try:
        ax = fig.subplots()
        ax.scatter(xs, ys, c=cs)
        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)
        ax.set_aspect('equal', 'box')
        # Hide ticks and tick labels on both axes; only the points matter.
        ax.tick_params(axis='x', which='both',
                       bottom=False, top=False, labelbottom=False)
        ax.tick_params(axis='y', which='both',
                       left=False, right=False, labelleft=False)
        fig.tight_layout()

        # Round-trip through an in-memory PNG rather than a file in /tmp:
        # no name collisions, nothing left on disk if something fails, and
        # it works on platforms without /tmp. Pin dpi to the figure's own
        # dpi so the 200 x 200 output does not depend on rcParams.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=fig.dpi)
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)

    buf.seek(0)
    im = plt.imread(buf, format='png')
    return im