in model.py [0:0]
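# This function relies on module-level imports at the top of model.py
# (tensorflow as tf, numpy as np, and a tensor-op helper module imported as Z)
# and on codec(), prior(), and abstract_model_xy() defined elsewhere in the file.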
def model(sess, hps, train_iterator, test_iterator, data_init):

    # Only for decoding/init, rest use iterators directly
    with tf.name_scope('input'):
        X = tf.placeholder(
            tf.uint8, [None, hps.image_size, hps.image_size, 3], name='image')
        Y = tf.placeholder(tf.int32, [None], name='label')
        lr = tf.placeholder(tf.float32, None, name='learning_rate')

    encoder, decoder = codec(hps)
    hps.n_bins = 2. ** hps.n_bits_x

    def preprocess(x):
        x = tf.cast(x, 'float32')
        if hps.n_bits_x < 8:
            x = tf.floor(x / 2 ** (8 - hps.n_bits_x))
        x = x / hps.n_bins - .5
        return x

    def postprocess(x):
        return tf.cast(tf.clip_by_value(tf.floor((x + .5) * hps.n_bins) * (256. / hps.n_bins), 0, 255), 'uint8')
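    # Worked example of the quantization above (illustrative values): with
    # hps.n_bits_x = 5 and n_bins = 32, a uint8 pixel 200 becomes
    # floor(200 / 8) = 25, then 25 / 32 - 0.5 = 0.28125; postprocess maps it back
    # via floor((0.28125 + 0.5) * 32) * (256 / 32) = 25 * 8 = 200. With
    # n_bits_x = 8 the floor step is skipped and values stay at full 8-bit precision.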
    def _f_loss(x, y, is_training, reuse=False):

        with tf.variable_scope('model', reuse=reuse):
            y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

            # Discrete -> Continuous
            objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
            z = preprocess(x)
            z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
            objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

            # Encode
            z = Z.squeeze2d(z, 2)  # space-to-depth: HxWxC -> (H/2)x(W/2)x4C
            z, objective, _ = encoder(z, objective)

            # Prior
            hps.top_shape = Z.int_shape(z)[1:]
            logp, _, _ = prior("prior", y_onehot, hps)
            objective += logp(z)

            # Generative loss
            nobj = - objective
            bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int(
                x.get_shape()[2]) * int(x.get_shape()[3]))  # bits per subpixel

            # Predictive loss
            if hps.weight_y > 0 and hps.ycond:
                # Classification loss
                h_y = tf.reduce_mean(z, axis=[1, 2])
                y_logits = Z.linear_zeros("classifier", h_y, hps.n_y)
                bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=y_onehot, logits=y_logits) / np.log(2.)

                # Classification accuracy
                y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
                classification_error = 1 - \
                    tf.cast(tf.equal(y_predicted, y), tf.float32)
            else:
                bits_y = tf.zeros_like(bits_x)
                classification_error = tf.ones_like(bits_x)

        return bits_x, bits_y, classification_error
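    # The objective assembled in _f_loss is the discrete log-likelihood of x:
    # a -D * log(n_bins) dequantization term for the uniform noise added to the
    # D = H*W*C subpixels, plus the log-determinant contributions accumulated by
    # encoder(), plus the log-density of the top-level prior. Negating and dividing
    # by log(2) * H * W * C converts it to bits per subpixel, reported as bits_x.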
    def f_loss(iterator, is_training, reuse=False):
        if hps.direct_iterator and iterator is not None:
            x, y = iterator.get_next()
        else:
            x, y = X, Y

        bits_x, bits_y, pred_loss = _f_loss(x, y, is_training, reuse)
        local_loss = bits_x + hps.weight_y * bits_y
        stats = [local_loss, bits_x, bits_y, pred_loss]
        global_stats = Z.allreduce_mean(
            tf.stack([tf.reduce_mean(i) for i in stats]))

        return tf.reduce_mean(local_loss), global_stats
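    # f_loss combines the generative and (optional) classification terms as
    # bits_x + hps.weight_y * bits_y; the stacked per-batch means are reduced
    # across workers by Z.allreduce_mean so the logged stats reflect the global batch.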
    feeds = {'x': X, 'y': Y}
    m = abstract_model_xy(sess, hps, feeds, train_iterator,
                          test_iterator, data_init, lr, f_loss)
    # === Sampling function
    def f_sample(y, eps_std):
        with tf.variable_scope('model', reuse=True):
            y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

            _, sample, _ = prior("prior", y_onehot, hps)
            z = sample(eps_std=eps_std)
            z = decoder(z, eps_std=eps_std)
            z = Z.unsqueeze2d(z, 2)  # depth-to-space: (H/2)x(W/2)x4C -> HxWxC
            x = postprocess(z)

        return x

    m.eps_std = tf.placeholder(tf.float32, [None], name='eps_std')
    x_sampled = f_sample(Y, m.eps_std)

    def sample(_y, _eps_std):
        return m.sess.run(x_sampled, {Y: _y, m.eps_std: _eps_std})
    m.sample = sample
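    # Illustrative use of m.sample (the values below are assumptions, not part of
    # this file): draw 8 class-conditional samples at temperature 0.7.
    #   images = m.sample(np.zeros([8], dtype=np.int32), [0.7] * 8)
    #   # -> uint8 array of shape [8, hps.image_size, hps.image_size, 3]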
    if hps.inference:
        # === Encoder-Decoder functions
        def f_encode(x, y, reuse=True):
            with tf.variable_scope('model', reuse=reuse):
                y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

                # Discrete -> Continuous
                objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
                z = preprocess(x)
                z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
                objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

                # Encode
                z = Z.squeeze2d(z, 2)  # space-to-depth: HxWxC -> (H/2)x(W/2)x4C
                z, objective, eps = encoder(z, objective)

                # Prior
                hps.top_shape = Z.int_shape(z)[1:]
                logp, _, _eps = prior("prior", y_onehot, hps)
                objective += logp(z)
                eps.append(_eps(z))

            return eps

        def f_decode(y, eps, reuse=True):
            with tf.variable_scope('model', reuse=reuse):
                y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

                _, sample, _ = prior("prior", y_onehot, hps)
                z = sample(eps=eps[-1])
                z = decoder(z, eps=eps[:-1])
                z = Z.unsqueeze2d(z, 2)  # depth-to-space: (H/2)x(W/2)x4C -> HxWxC
                x = postprocess(z)

            return x
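        # eps layout: encoder() returns one eps tensor per split layer (the coarser
        # image levels), and f_encode appends the top-level prior eps last, so
        # f_decode feeds eps[-1] to the prior's sample() and eps[:-1] back into decoder().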
        enc_eps = f_encode(X, Y)
        dec_eps = []
        print(enc_eps)
        for i, _eps in enumerate(enc_eps):
            print(_eps)
            dec_eps.append(tf.placeholder(
                tf.float32, _eps.get_shape().as_list(), name="dec_eps_" + str(i)))
        dec_x = f_decode(Y, dec_eps)

        eps_shapes = [_eps.get_shape().as_list()[1:] for _eps in enc_eps]

        def flatten_eps(eps):
            # [BS, eps_size]
            return np.concatenate([np.reshape(e, (e.shape[0], -1)) for e in eps], axis=-1)

        def unflatten_eps(feps):
            index = 0
            eps = []
            bs = feps.shape[0]
            for shape in eps_shapes:
                eps.append(np.reshape(
                    feps[:, index: index + np.prod(shape)], (bs, *shape)))
                index += np.prod(shape)
            return eps
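        # flatten_eps/unflatten_eps pack the per-level eps tensors into a single
        # [batch, total_dims] array and back. Illustrative shapes, assuming a
        # 32x32x3 input with 3 levels: (bs,16,16,6), (bs,8,8,12) and a top eps of
        # (bs,4,4,48) flatten to 1536 + 768 + 768 = 3072 dims, i.e. one value per
        # subpixel, as expected for a dimension-preserving invertible model.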
        # If model is unconditional, always pass y = np.zeros([bs], dtype=np.int32)
        def encode(x, y):
            return flatten_eps(sess.run(enc_eps, {X: x, Y: y}))

        def decode(y, feps):
            eps = unflatten_eps(feps)
            feed_dict = {Y: y}
            for i in range(len(dec_eps)):
                feed_dict[dec_eps[i]] = eps[i]
            return sess.run(dec_x, feed_dict)

        m.encode = encode
        m.decode = decode

    return m
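# Illustrative round trip through the inference API (requires hps.inference=True;
# the data below is an assumption, not part of this file):
#   x = np.random.randint(0, 256, size=(4, hps.image_size, hps.image_size, 3), dtype=np.uint8)
#   y = np.zeros([4], dtype=np.int32)   # always zeros for unconditional models
#   feps = m.encode(x, y)               # [4, image_size * image_size * 3]
#   x_rec = m.decode(y, feps)           # reconstructs x up to (de)quantization noise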