in train.py [0:0]
def init_visualizations(hps, model, logdir):
    """Build the per-epoch sample-drawing callback.

    Args:
        hps: Hyper-parameter object. This function reads
            ``hps.local_batch_train`` (chunk size used when sampling),
            ``hps.image_size`` (side length used for the reshape and for
            choosing the grid size) and ``hps.n_y`` (labels are taken
            modulo this value).
        model: Object exposing ``sample(y, eps)``; the per-chunk results
            are joined with ``np.concatenate``.
        logdir: String prefix for the output files. It is concatenated
            directly with the file name, so it presumably has to end with
            a path separator -- confirm against the caller.

    Returns:
        ``draw_samples(epoch)``, a function that samples one grid per
        temperature and hands each grid to ``graphics.save_raster``.
    """

    def sample_batch(y, eps):
        """Call ``model.sample`` on consecutive chunks and join the results.

        ``y`` and ``eps`` are sliced in lockstep into chunks of at most
        ``hps.local_batch_train`` items; the last chunk may be shorter.
        """
        chunk = hps.local_batch_train
        # Stepping by the chunk size covers a trailing partial chunk without
        # any float ceil arithmetic.
        pieces = [
            model.sample(y[start:start + chunk], eps[start:start + chunk])
            for start in range(0, len(eps), chunk)
        ]
        return np.concatenate(pieces)

    def draw_samples(epoch):
        """Save one sample grid per temperature for ``epoch``.

        Does nothing unless ``hvd.rank()`` is 0, so only one worker writes
        image files.
        """
        if hvd.rank() != 0:
            return

        # 10x10 grid for images up to 64 pixels a side, 4x4 for larger ones.
        rows = 10 if hps.image_size <= 64 else 4
        cols = rows
        n_batch = rows * cols

        # Labels cycle 0..cols-1 along each row, wrapped into [0, hps.n_y).
        y = np.asarray(
            [label % hps.n_y for label in list(range(cols)) * rows],
            dtype='int32')

        # Single source of truth for the sampled temperatures; the position
        # in this list is the index used in the output file name.
        # An earlier schedule was 0, .25, .5, .625, .75, .875, 1.
        temperatures = [0., .25, .5, .6, .7, .8, .9, 1.]

        # Draw every grid first, then write them out, so nothing is saved
        # if any sampling call fails.
        x_samples = [
            sample_batch(y, [temperature] * n_batch)
            for temperature in temperatures
        ]

        for i, x_sample in enumerate(x_samples):
            # Last axis has size 3 (presumably RGB channels).
            x_sample = np.reshape(
                x_sample, (n_batch, hps.image_size, hps.image_size, 3))
            graphics.save_raster(
                x_sample,
                logdir + 'epoch_{}_sample_{}.png'.format(epoch, i))

    return draw_samples