in reproduce/plot_variance_ratio.py [0:0]
def _natural_sort_key(text):
    """Return a sort key that orders embedded digit runs numerically.

    The text is split into digit and non-digit runs; digit runs are converted
    to ``int`` so that e.g. ``"epoch10"`` sorts after ``"epoch2"``.
    """
    return [int(tok) if tok.isdigit() else tok
            for tok in re.split(r'(\d+)', text)]


def _load_variance_ratio_traces(data_files):
    """Read pickled run data and group variance ratios by evaluation point.

    Args:
        data_files: iterable of paths to pickle files, processed in order.
            Each file must unpickle to a mapping; files without a
            ``"batch_indices"`` key are skipped.

    Returns:
        A dict mapping a percentage label such as ``"33%"`` to a pair
        ``(epochs, ratios)`` of equally long lists, in file order. Keeping the
        epoch next to every ratio guarantees the x and y data of each trace
        always line up, even if a label is absent from some files or two
        evaluation points in one file round to the same label.
    """
    trace_data = {}
    for fname in data_files:
        print("(ALL) processing file ", fname)
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open(fname, 'rb') as fdata:
            rd = pickle.load(fdata)

        if 'batch_indices' not in rd:
            continue
        print("Has batch indices")

        # Normalise the evaluation positions to fractions in [0, 1].
        batch_indices = np.array(rd["batch_indices"])
        nk = len(batch_indices)
        if batch_indices.max() == batch_indices.min():
            # All indices identical: fall back to evenly spaced positions.
            eval_points = np.arange(nk) / nk
        else:
            eval_points = batch_indices / batch_indices.max()

        epoch = rd["epoch"]
        ratio_points = (np.array(rd["vr_step_variances"])
                        / np.array(rd["gradient_variances"])).tolist()

        for i, ep in enumerate(eval_points):
            ep_name = "{0:.0f}%".format(100*ep)
            trace_epochs, trace_ratios = trace_data.setdefault(ep_name, ([], []))
            trace_epochs.append(epoch)
            trace_ratios.append(ratio_points[i])
    return trace_data


def plot_variance_ratios(plot_name, data_files_grob, xvals):
    """Plot the per-epoch variance ratio and save it as a PDF.

    For every pickle file matching ``data_files_grob`` (visited in natural
    sort order) the element-wise ratio ``vr_step_variances /
    gradient_variances`` is computed, and the ratios at a fixed set of
    evaluation points (2%, 11%, 33%, 100%) are drawn against the file's epoch
    on a base-2 logarithmic y axis.

    Args:
        plot_name: base name of the output file; the figure is written to
            ``"{plot_dir}/{plot_name}.pdf"``.
        data_files_grob: glob pattern selecting the pickled data files.
        xvals: forwarded to ``labelLines`` as ``xvals`` (positions of the
            in-plot labels along the x axis).

    Relies on the module-level names ``plt``, ``colors``, ``plot_dir``,
    ``label_fontsize`` and ``labelLines``. The figure is closed once it has
    been saved so repeated calls do not accumulate open figures.
    """
    ylabel = "SVR Variance / SGD Variance"
    keys_to_show = ['2%', '11%', '33%', '100%']

    data_files = sorted(glob.glob(data_files_grob), key=_natural_sort_key)
    trace_data = _load_variance_ratio_traces(data_files)

    fig = plt.figure(figsize=(3.2,2))
    try:
        ax = fig.add_subplot(111)
        ax.set_prop_cycle("color", colors)

        for ep_name, (trace_epochs, trace_ratios) in trace_data.items():
            if ep_name in keys_to_show:
                ax.plot(trace_epochs, trace_ratios, ".", label=ep_name)
            if ep_name == "100%":
                print("100p epochs:", trace_epochs)
                print("ratios: ", trace_ratios)

        print("Finalizing plot")
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        try:
            # Keyword understood by matplotlib < 3.5.
            ax.set_yscale("log", basey=2)
        except TypeError:
            # matplotlib >= 3.5 removed ``basey`` in favour of ``base``.
            ax.set_yscale("log", base=2)
        ax.set_yticks([2**(-i) for i in range(0, 11)])
        ax.set_ylim([1e-3, 3])
        ax.set_xlim([0.0, 240])

        # Shade the band of ratios between 1 and 2.
        ax.axhspan(1, 2, alpha=0.3, facecolor='red', edgecolor=None)

        # Step size reduction indicators
        ax.axvline(x=150.0, color="brown", linestyle='--')
        ax.axvline(x=220.0, color="brown", linestyle='--')

        ax.grid(False)
        ax.xaxis.set_tick_params(direction='in')
        ax.yaxis.set_tick_params(direction='in', right=True)

        labelLines(ax.get_lines(), align=False, fontsize=label_fontsize, xvals=xvals)

        figname = "{}/{}.pdf".format(plot_dir, plot_name)
        fig.savefig(figname, bbox_inches='tight', pad_inches=0)
        print("saved", figname)
    finally:
        # Release the figure whether or not saving succeeded.
        plt.close(fig)