def plot_pianoroll()

in gan/utils/display_utils.py [0:0]


def plot_pianoroll(iteration, xs, fake_xs, min_pitch=45, max_pitch=85,
                   programs=(0, 0, 0, 0), save_dir=None):
    """Plot real vs. generated multi-track piano rolls and save the figure.

    Up to four sample pairs are drawn on a 2x2 outer grid. Each cell holds a
    2 x ``n_tracks`` inner grid: the top row shows the tracks of the real
    sample ``xs[i]`` and the bottom row the tracks of the generated sample
    ``fake_xs[i]``. Every strictly positive entry is drawn as a dot at
    (time, pitch).

    :param iteration: Training iteration. If it is an integer (Python or
        NumPy), the figure is titled ``iteration: N`` and saved as
        ``sample_iteration_NNNNN.png``. Otherwise it is titled ``Inference``
        and saved as ``sample_inference.png``.
    :param xs: Batch of real piano rolls supporting ``> 0`` (e.g. a NumPy
        array). Each sample must be 3-D and is interpreted as
        ``(time, pitch, track)``.
    :param fake_xs: Batch of generated piano rolls with the same layout as
        ``xs``. Values ``<= 0`` (such as a -1 "off" value) are treated as
        silent.
    :param min_pitch: Lower y-axis (pitch) limit shared by all subplots.
    :param max_pitch: Upper y-axis (pitch) limit shared by all subplots.
    :param programs: MIDI program number per track, used to label each
        x-axis with its instrument class. It needs at least ``n_tracks``
        entries. Pass an empty sequence or None to omit the labels.
    :param save_dir: Directory to write the PNG into. ``None`` writes it to
        the current working directory.
    :return: None
    :raises ValueError: if a sample is not 3-D, or if either batch is empty.
    """
    # Binarize: any strictly positive value counts as a note. This also maps
    # a -1 "off" value to False.
    xs = xs > 0
    fake_xs = fake_xs > 0

    def to_track_first(sample):
        # (time, pitch, track) -> (track, time, pitch), so that each track
        # becomes a 2-D slice. Validate before moveaxis so that a bad shape
        # yields a clear error instead of a numpy AxisError.
        sample = np.array(sample)
        if sample.ndim != 3:
            raise ValueError(
                'Pianoroll shape must have 3 dims, Got %d' % sample.ndim)
        return np.moveaxis(sample, 2, 0)

    xs = [to_track_first(x) for x in xs]
    fake_xs = [to_track_first(fake_x) for fake_x in fake_xs]

    # The outer grid has room for 4 samples; plot fewer if fewer were given.
    n_samples = min(4, len(xs), len(fake_xs))
    if n_samples == 0:
        raise ValueError('xs and fake_xs must each contain at least one sample')
    n_tracks, time_step, _ = xs[0].shape

    # One colour per track; cycles when there are more tracks than colours.
    colors = 'bgrcmk'

    # NOTE(review): ion() switches pyplot's global interactive mode on. It is
    # kept for backward compatibility, although this function only saves.
    plt.ion()
    fig = plt.figure(figsize=(15, 8))
    try:
        # gridspec inside gridspec: 2x2 samples, each a 2 x n_tracks grid.
        outer_grid = gridspec.GridSpec(2, 2, wspace=0.1, hspace=0.2)

        for i in range(n_samples):
            inner_grid = gridspec.GridSpecFromSubplotSpec(
                2, n_tracks, subplot_spec=outer_grid[i],
                wspace=0.0, hspace=0.0)

            # Row 0 holds the real sample and row 1 the generated one.
            for row, sample in enumerate((xs[i], fake_xs[i])):
                for track in range(n_tracks):
                    # add_subplot already attaches ax to fig.
                    ax = fig.add_subplot(inner_grid[row, track])
                    time_idx, pitch_idx = np.nonzero(sample[track])

                    if programs:
                        instrument = pretty_midi.program_to_instrument_class(
                            programs[track])
                        ax.set_xlabel('Time(' + instrument + ')')

                    # Only the left-most track keeps the pitch axis labels.
                    if track == 0:
                        ax.set_ylabel('Pitch')
                    else:
                        ax.set_yticks([])

                    ax.scatter(time_idx, pitch_idx, s=np.pi * 3,
                               color=colors[track % len(colors)])
                    ax.set_ylim(min_pitch, max_pitch)
                    ax.set_xlim(0, time_step)

        if isinstance(iteration, (int, np.integer)):
            fig.suptitle('iteration: {}'.format(iteration), fontsize=20)
            filename = 'sample_iteration_%05d.png' % iteration
        else:
            fig.suptitle('Inference', fontsize=20)
            filename = 'sample_inference.png'
        if save_dir is not None:
            filename = os.path.join(save_dir, filename)
        fig.savefig(filename)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)