in gan/utils/metrics_utils.py [0:0]
def plot_metrics(self, num_instruments=4):
    """Plot every tracked metric against iteration, one column per instrument.

    Builds a grid with one row per entry in ``self.metrics`` and one column per
    instrument. Each cell shows:

    * the per-iteration metric values as a scatter plot;
    * the metric's reference value as a flat red line spanning the same
      iterations.

    All cells share a logarithmic iteration (x) axis.

    ``self.metrics`` is read as a mapping from a metric-creator object, which
    must expose a ``label`` attribute used as the row's y label, to a dict with:

    * ``'per_iteration'`` -- a sequence of entries where ``entry[0]`` is the
      iteration and ``entry[1][i]`` is the value for instrument ``i``;
    * ``'reference'`` -- a sequence where ``reference[i]`` is the reference
      value for instrument ``i``.

    Args:
        num_instruments: Number of instrument columns to draw. Defaults to 4,
            the value previously hard-coded. Both ``entry[1]`` and
            ``reference`` must have at least this many items.

    Returns:
        None. Nothing is drawn when ``self.metrics`` is empty.
    """
    if not self.metrics:
        # plt.subplots raises on a zero-row grid; there is simply nothing to plot.
        return

    num_metrics = len(self.metrics)
    plt.ion()
    sns.set()
    # squeeze=False keeps ``axs`` 2-D even with a single metric row, so
    # ``axs[row, col]`` indexing is valid for any grid size.
    fig, axs = plt.subplots(
        num_metrics,
        num_instruments,
        sharex=True,
        squeeze=False,
        figsize=(60, 30),
    )

    for metric_idx, (metric_creator, metric_data) in enumerate(self.metrics.items()):
        per_iteration = metric_data['per_iteration']
        # Iteration numbers are the same for every instrument column; compute once.
        iterations = [entry[0] for entry in per_iteration]
        for instrument_idx in range(num_instruments):
            ax = axs[metric_idx, instrument_idx]
            ax.set_xscale('log')
            ax.tick_params(axis='both', which='both', labelsize=30)
            # Reference value drawn as a flat line across the observed iterations.
            ax.plot(
                iterations,
                np.ones(len(per_iteration)) * metric_data['reference'][instrument_idx],
                'r',
                linewidth=10,
                alpha=0.7,
            )
            # Per-iteration metric values for this instrument.
            ax.scatter(
                iterations,
                [entry[1][instrument_idx] for entry in per_iteration],
                linewidth=10,
            )
        axs[metric_idx, 0].set_ylabel(ylabel=metric_creator.label, fontsize=40)

    # With sharex=True only the bottom row needs x labels. Index -1 instead of a
    # hard-coded row so any number of metrics works.
    for instrument_idx in range(num_instruments):
        axs[-1, instrument_idx].set_xlabel(
            xlabel="Iterations (Instrument {})".format(instrument_idx),
            fontsize=40,
        )

    # Compute the layout only after all artists and labels exist, so the large
    # tick/axis labels are taken into account.
    fig.tight_layout()
    plt.show()