in reinvent-labs/lab-2/utils/display_utils.py [0:0]
def show_pianoroll(xs, min_pitch=45, max_pitch=85,
                   programs=(0, 0, 0, 0), save_dir=None):
    """Plot the first multi-track pianoroll of a batch as scatter plots.

    Only ``xs[0]`` is drawn. Each of its tracks (at most four) gets its own
    subplot in a single row of four columns, with the time step on the
    x-axis and the pitch on the y-axis. Every entry greater than zero is
    drawn as one note marker.

    :param xs: Batch of pianorolls supporting ``xs > 0`` and indexing
        (e.g. a NumPy array). Each sample must be 3-D with the track axis
        last; the axis order looks like ``(time_step, pitch, track)`` --
        TODO confirm against the callers.
    :param min_pitch: Min Pitch / lower limit of the y-axis for all tracks.
    :param max_pitch: Max Pitch / upper limit of the y-axis for all tracks.
    :param programs: Program numbers of the tracks, used to label each
        x-axis with the instrument class. Cycled if shorter than the number
        of plotted tracks. Pass an empty sequence or ``None`` to omit labels.
    :param save_dir: Optional. Currently unused by this function; accepted
        so existing callers keep working.
    :return: None. The figure is created with interactive mode switched on.
    :raises AssertionError: If the first sample has more than 3 dimensions
        (fewer than 3 already fails inside ``np.moveaxis``).
    """
    max_tracks = 4  # The figure layout is a fixed row of four subplots.
    track_colors = 'bgrcmk'

    # Binarise: anything above zero is an active note (inputs may use -1
    # for silence, which becomes False here).
    xs = xs > 0

    # Only the first sample is drawn, so only that one is converted.
    # Moving the trailing track axis to the front gives one 2-D roll per
    # track: (track, time_step, pitch).
    pianoroll = np.moveaxis(np.array(xs[0]), 2, 0)
    assert pianoroll.ndim == 3, \
        'Pianoroll shape must have 3 dims, Got %d' % pianoroll.ndim
    n_tracks, time_step, _ = pianoroll.shape

    plt.ion()
    fig = plt.figure(figsize=(15, 4))

    # Never index past the tracks that actually exist in the sample.
    for track in range(min(n_tracks, max_tracks)):
        ax = fig.add_subplot(1, max_tracks, track + 1)
        times, pitches = np.nonzero(pianoroll[track])

        if programs:
            program = programs[track % len(programs)]
            instrument = pretty_midi.program_to_instrument_class(program)
            ax.set_xlabel('Time(' + instrument + ')')

        # Only the left-most subplot carries the shared pitch axis.
        if track == 0:
            ax.set_ylabel('Pitch')
        else:
            ax.set_yticks([])

        ax.scatter(times, pitches, s=np.pi * 3, color=track_colors[track])
        # Honour the caller-supplied pitch range instead of fixed limits.
        ax.set_ylim(min_pitch, max_pitch)
        ax.set_xlim(0, time_step)