def plot_metrics()

in gan/utils/metrics_utils.py [0:0]


    def plot_metrics(self, num_instruments=4):
        """Plot every tracked metric against iterations, one column per instrument.

        Builds a grid with one row per entry in ``self.metrics`` and one column
        per instrument. Each subplot shows:

        * a thick red line at the metric's reference value for that instrument,
          spanning the recorded iterations;
        * a scatter of the metric's per-iteration values for that instrument.

        The x axis (iterations) is shared across subplots and log-scaled. The
        figure is drawn with interactive mode enabled and then shown.

        ``self.metrics`` is read as a mapping whose keys expose a ``label``
        attribute (used as the row's y-label) and whose values are dicts with:

        * ``'per_iteration'``: sequence of ``(iteration, values)`` pairs, where
          ``values`` is indexed by instrument;
        * ``'reference'``: sequence of reference values indexed by instrument.

        Args:
            num_instruments: Number of instrument columns to draw. Defaults to 4
                (the value previously hard-coded here); each metric's
                ``values``/``'reference'`` must have at least this many entries.

        Returns:
            None. Nothing is drawn when ``self.metrics`` is empty.
        """
        if not self.metrics:
            # plt.subplots(0, n) raises, and there is nothing to plot anyway.
            return

        num_metrics = len(self.metrics)
        plt.ion()
        sns.set()
        # squeeze=False keeps ``axs`` 2-D even for a single metric row or a
        # single instrument column, so ``axs[row][col]`` indexing always works.
        fig, axs = plt.subplots(
            num_metrics,
            num_instruments,
            sharex=True,
            figsize=(60, 30),
            squeeze=False,
        )

        for metric_idx, (metric_creator, metric_data) in enumerate(self.metrics.items()):
            per_iteration = metric_data['per_iteration']
            # The x coordinates are the same for every instrument column.
            iterations = [x[0] for x in per_iteration]

            for instrument_idx in range(num_instruments):
                ax = axs[metric_idx][instrument_idx]
                ax.set_xscale('log')
                ax.tick_params(axis='both', which='major', labelsize=30)
                ax.tick_params(axis='both', which='minor', labelsize=30)

                # Reference line: constant reference value across all iterations.
                ax.plot(
                    iterations,
                    np.full(len(per_iteration), metric_data['reference'][instrument_idx]),
                    'r',
                    linewidth=10,
                    alpha=0.7,
                )

                # Per-iteration metric values for this instrument.
                ax.scatter(
                    iterations,
                    [x[1][instrument_idx] for x in per_iteration],
                    linewidth=10,
                )

        # X labels go on the bottom row, whatever the number of metrics.
        for instrument_idx in range(num_instruments):
            label = "Iterations (Instrument {})".format(instrument_idx)
            axs[-1][instrument_idx].set_xlabel(xlabel=label, fontsize=40)

        # Y labels go on the first column, one per metric row.
        for metric_idx, metric_creator in enumerate(self.metrics):
            axs[metric_idx][0].set_ylabel(ylabel=metric_creator.label, fontsize=40)

        # Lay out only after ticks and labels exist so their size is accounted for.
        fig.tight_layout()
        plt.show()