def show_pianoroll()

in gan/utils/display_utils.py [0:0]


def show_pianoroll(xs, min_pitch=45, max_pitch=85,
                   programs=(0, 0, 0, 0), save_dir=None):
    """Plot the first sample of a batch of multi-track piano rolls.

    Each track is drawn in its own subplot (one row, side by side) as a
    scatter plot of its active (time, pitch) cells.

    :param xs: Batch of multi-instrument piano rolls (anything supporting
        indexing and ``> 0``, e.g. a NumPy array or CPU tensor). Only
        ``xs[0]`` is plotted. Each sample must have 3 dims. Axis 2 is moved
        to the front, and the result is read as (track, time, pitch).
    :param min_pitch: Lower y-limit (pitch) shared by all track subplots.
    :param max_pitch: Upper y-limit (pitch) shared by all track subplots.
    :param programs: MIDI program number per track, used to label each
        x-axis with its instrument class. It is cycled if shorter than the
        number of tracks. Pass an empty sequence or ``None`` to skip the labels.
    :param save_dir: Optional. NOTE(review): not referenced by this
        function's body. An older docstring described it as a file name
        for saving the plot; confirm the intended behaviour.
    :raises ValueError: if the first sample does not have exactly 3 dims.
    :return: None. The figure is drawn with matplotlib interactive mode on.
    """
    # Binarize: any cell with a positive value counts as an active note.
    # Zero and negative values (e.g. -1) count as off. Only the first sample
    # is plotted, so only it is converted.
    sample = np.asarray(xs[0] > 0)
    if sample.ndim != 3:
        raise ValueError(
            'Pianoroll shape must have 3 dims, Got %d' % sample.ndim)

    # Move axis 2 (the track axis) to the front: result is (track, time, pitch).
    x = np.moveaxis(sample, 2, 0)
    n_tracks, time_step, _ = x.shape

    plt.ion()
    fig = plt.figure(figsize=(15, 4))

    colors = 'bgrcmk'
    for track in range(n_tracks):
        ax = fig.add_subplot(1, n_tracks, track + 1)
        # Indices of the active cells of this track: (time steps, pitches).
        times, pitches = np.nonzero(x[track])

        if programs:
            instrument = pretty_midi.program_to_instrument_class(
                programs[track % len(programs)])
            ax.set_xlabel('Time(' + instrument + ')')

        # Only the leftmost subplot gets a y-axis label. The other subplots
        # use the same pitch limits, so their y ticks are hidden.
        if track == 0:
            ax.set_ylabel('Pitch')
        else:
            ax.set_yticks([])

        ax.scatter(times, pitches, s=np.pi * 3,
                   color=colors[track % len(colors)])
        ax.set_ylim(min_pitch, max_pitch)
        ax.set_xlim(0, time_step)