in gan/utils/display_utils.py [0:0]
def _tracks_first(sample):
    """Return ``sample`` as an ndarray with its last (track) axis moved to the front.

    ``np.moveaxis(sample, 2, 0)`` turns a 3-D sample whose axes are
    (time, pitch, track) into (track, time, pitch).
    NOTE(review): the (time, pitch) meaning of the first two axes is inferred
    from how the result is unpacked and plotted — confirm against the caller.

    :raises ValueError: if the sample is not 3-D.
    """
    sample = np.asarray(sample)
    if sample.ndim != 3:
        raise ValueError('Pianoroll shape must have 3 dims, Got %d' % sample.ndim)
    return np.moveaxis(sample, 2, 0)


def _plot_real_fake_pair(fig, subplot_spec, x, fake_x, min_pitch, max_pitch, programs):
    """Draw one real sample (top row) and one generated sample (bottom row).

    Each track gets its own column inside ``subplot_spec``; notes are drawn as
    scatter points at their (time, pitch) coordinates.
    """
    n_tracks, time_step, _ = x.shape
    inner_grid = gridspec.GridSpecFromSubplotSpec(
        2, n_tracks, subplot_spec=subplot_spec, wspace=0.0, hspace=0.0)
    colors = 'bgrcmk'  # cycled so that more than six tracks do not overflow
    for row, sample in enumerate((x, fake_x)):
        for track in range(n_tracks):
            ax = fig.add_subplot(inner_grid[row, track])
            # nonzero on a 2-D (time, pitch) slice -> (time indices, pitch indices)
            times, pitches = np.nonzero(sample[track])
            if programs and track < len(programs):
                instrument = pretty_midi.program_to_instrument_class(programs[track])
                ax.set_xlabel('Time(' + instrument + ')')
            # Only the left-most column keeps a pitch axis; the others share it visually.
            if track == 0:
                ax.set_ylabel('Pitch')
            else:
                ax.set_yticks([])
            ax.scatter(times, pitches, s=np.pi * 3, color=colors[track % len(colors)])
            ax.set_ylim(min_pitch, max_pitch)
            ax.set_xlim(0, time_step)


def _title_and_filename(iteration, save_dir):
    """Return the figure title and output PNG path for ``iteration``."""
    if save_dir is None:
        save_dir = os.curdir
    if isinstance(iteration, (int, np.integer)):
        return ('iteration: {}'.format(iteration),
                os.path.join(save_dir, 'sample_iteration_%05d.png' % iteration))
    return 'Inference', os.path.join(save_dir, 'sample_inference.png')


def plot_pianoroll(iteration, xs, fake_xs, min_pitch=45, max_pitch=85,
                   programs=(0, 0, 0, 0), save_dir=None):
    """ Plot real vs. generated multi-track pianorolls and save the figure as a PNG.

    Up to four (real, fake) pairs are drawn on a 2x2 grid. Inside each cell the
    real sample occupies the top row and the generated sample the bottom row,
    with one column per track.

    :param iteration: Training step. An integer (Python or numpy) produces the
        title "iteration: N" and file ``sample_iteration_%05d.png``; any other
        value produces the title "Inference" and file ``sample_inference.png``.
    :param xs: Batch of real pianorolls. Every cell ``> 0`` is treated as a note.
        Each sample must be 3-D with the track axis last.
    :param fake_xs: Batch of generated pianorolls, same layout as ``xs``.
    :param min_pitch: Lower y-limit (pitch) of every subplot.
    :param max_pitch: Upper y-limit (pitch) of every subplot.
    :param programs: MIDI program number per track, used for the x-axis labels
        via ``pretty_midi.program_to_instrument_class``. Falsy disables labels;
        tracks beyond ``len(programs)`` are left unlabelled.
    :param save_dir: Directory to write the PNG into (default: current directory).
    :raises ValueError: if a sample is not 3-D or either batch is empty.
    :return: None
    """
    # Binarise: any strictly positive cell is a note, everything else (e.g. -1, 0) is off.
    xs = [_tracks_first(x) for x in (xs > 0)]
    fake_xs = [_tracks_first(fake_x) for fake_x in (fake_xs > 0)]
    n_pairs = min(4, len(xs), len(fake_xs))
    if n_pairs == 0:
        raise ValueError('xs and fake_xs must each contain at least one sample')
    title, filename = _title_and_filename(iteration, save_dir)

    # NOTE(review): kept from the original; this toggles pyplot's global
    # interactive mode as a side effect — confirm it is still wanted.
    plt.ion()
    fig = plt.figure(figsize=(15, 8))
    try:
        outer_grid = gridspec.GridSpec(2, 2, wspace=0.1, hspace=0.2)
        for i in range(n_pairs):
            _plot_real_fake_pair(fig, outer_grid[i], xs[i], fake_xs[i],
                                 min_pitch, max_pitch, programs)
        fig.suptitle(title, fontsize=20)
        fig.savefig(filename)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)