model.py:

import tensorflow as tf
import tfops as Z
import optim
import numpy as np
import horovod.tensorflow as hvd
from tensorflow.contrib.framework.python.ops import add_arg_scope

'''
f_loss: function taking (iterator, is_training, reuse=False) as input and
returning a list/tuple whose first element is the loss.
'''


def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator,
                      data_init, lr, f_loss):

    # == Create class with static fields and methods
    class m(object):
        pass
    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # === Loss and optimizer
    loss_train, stats_train = f_loss(train_iterator, True)
    all_params = tf.trainable_variables()
    if hps.gradient_checkpointing == 1:
        from memory_saving_gradients import gradients
        gs = gradients(loss_train, all_params)
    else:
        gs = tf.gradients(loss_train, all_params)

    optimizer = {'adam': optim.adam, 'adamax': optim.adamax,
                 'adam2': optim.adam2}[hps.optimizer]
    train_op, polyak_swap_op, ema = optimizer(
        all_params, gs, alpha=lr, hps=hps)

    if hps.direct_iterator:
        m.train = lambda _lr: sess.run([train_op, stats_train], {lr: _lr})[1]
    else:
        def _train(_lr):
            _x, _y = train_iterator()
            return sess.run([train_op, stats_train],
                            {feeds['x']: _x, feeds['y']: _y, lr: _lr})[1]
        m.train = _train

    m.polyak_swap = lambda: sess.run(polyak_swap_op)

    # === Testing
    loss_test, stats_test = f_loss(test_iterator, False, reuse=True)
    if hps.direct_iterator:
        m.test = lambda: sess.run(stats_test)
    else:
        def _test():
            _x, _y = test_iterator()
            return sess.run(stats_test, {feeds['x']: _x, feeds['y']: _y})
        m.test = _test

    # === Saving and restoring
    saver = tf.train.Saver()
    saver_ema = tf.train.Saver(ema.variables_to_restore())
    m.save_ema = lambda path: saver_ema.save(
        sess, path, write_meta_graph=False)
    m.save = lambda path: saver.save(sess, path, write_meta_graph=False)
    m.restore = lambda path: saver.restore(sess, path)

    # === Initialize the parameters
    if hps.restore_path != '':
        m.restore(hps.restore_path)
    else:
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, True, reuse=True)
        sess.run(tf.global_variables_initializer())
        sess.run(results_init, {feeds['x']: data_init['x'],
                                feeds['y']: data_init['y']})
    sess.run(hvd.broadcast_global_variables(0))

    return m


def codec(hps):

    def encoder(z, objective):
        eps = []
        for i in range(hps.n_levels):
            z, objective = revnet2d(str(i), z, objective, hps)
            if i < hps.n_levels-1:
                z, objective, _eps = split2d(
                    "pool"+str(i), z, objective=objective)
                eps.append(_eps)
        return z, objective, eps

    def decoder(z, eps=[None]*hps.n_levels, eps_std=None):
        for i in reversed(range(hps.n_levels)):
            if i < hps.n_levels-1:
                z = split2d_reverse("pool"+str(i), z,
                                    eps=eps[i], eps_std=eps_std)
            z, _ = revnet2d(str(i), z, 0, hps, reverse=True)
        return z

    return encoder, decoder
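

# Hedged illustration, not part of the original model: it traces the shape
# bookkeeping implied by codec() for assumed 32x32x3 inputs and
# hps.n_levels = 3. Z.squeeze2d halves each spatial dimension and multiplies
# channels by 4; split2d factors out half the channels as eps and squeezes
# the remaining half.
def _multiscale_shapes_example():
    h, w, c = 32 // 2, 32 // 2, 3 * 4  # after the initial Z.squeeze2d: 16x16x12
    n_levels = 3                       # assumed value of hps.n_levels
    eps_shapes = []
    for i in range(n_levels):
        # revnet2d is shape-preserving.
        if i < n_levels - 1:
            c //= 2                    # half the channels become eps[i]
            eps_shapes.append((h, w, c))
            h, w, c = h // 2, w // 2, c * 4  # squeeze2d on the remaining half
    assert eps_shapes == [(16, 16, 6), (8, 8, 12)]
    assert (h, w, c) == (4, 4, 48)     # becomes hps.top_shape for the prior
    return eps_shapes, (h, w, c)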


def prior(name, y_onehot, hps):

    with tf.variable_scope(name):
        n_z = hps.top_shape[-1]

        h = tf.zeros([tf.shape(y_onehot)[0]]+hps.top_shape[:2]+[2*n_z])
        if hps.learntop:
            h = Z.conv2d_zeros('p', h, 2*n_z)
        if hps.ycond:
            h += tf.reshape(Z.linear_zeros("y_emb", y_onehot,
                                           2*n_z), [-1, 1, 1, 2 * n_z])

        pz = Z.gaussian_diag(h[:, :, :, :n_z], h[:, :, :, n_z:])

    def logp(z1):
        objective = pz.logp(z1)
        return objective

    def sample(eps=None, eps_std=None):
        if eps is not None:
            # Already sampled eps. Don't use eps_std
            z = pz.sample2(eps)
        elif eps_std is not None:
            # Sample with given eps_std
            z = pz.sample2(pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1]))
        else:
            # Sample normally
            z = pz.sample
        return z

    def eps(z1):
        return pz.get_eps(z1)

    return logp, sample, eps
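

# Hedged worked example (not part of the original file) of the objective
# bookkeeping in _f_loss below: the discrete->continuous correction and the
# bits-per-subpixel normalisation, with an assumed n_bits_x = 5 and 32x32x3
# inputs.
def _bits_x_example():
    n_bins = 2. ** 5                     # hps.n_bins for n_bits_x = 5
    dims = 32 * 32 * 3                   # subpixels per image
    correction = -np.log(n_bins) * dims  # added to the per-image objective
    # If the model assigned exactly uniform density to the dequantized data,
    # the negative objective would be +log(n_bins)*dims nats, i.e. exactly
    # n_bits_x bits per subpixel after the bits_x normalisation.
    nobj = -correction
    bits_per_subpixel = nobj / (np.log(2.) * dims)
    assert np.isclose(bits_per_subpixel, 5.)
    return bits_per_subpixel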


def model(sess, hps, train_iterator, test_iterator, data_init):

    # Only for decoding/init, rest use iterators directly
    with tf.name_scope('input'):
        X = tf.placeholder(
            tf.uint8, [None, hps.image_size, hps.image_size, 3], name='image')
        Y = tf.placeholder(tf.int32, [None], name='label')
        lr = tf.placeholder(tf.float32, None, name='learning_rate')

    encoder, decoder = codec(hps)
    hps.n_bins = 2. ** hps.n_bits_x

    def preprocess(x):
        x = tf.cast(x, 'float32')
        if hps.n_bits_x < 8:
            x = tf.floor(x / 2 ** (8 - hps.n_bits_x))
        x = x / hps.n_bins - .5
        return x

    def postprocess(x):
        return tf.cast(tf.clip_by_value(
            tf.floor((x + .5)*hps.n_bins)*(256./hps.n_bins), 0, 255), 'uint8')

    def _f_loss(x, y, is_training, reuse=False):

        with tf.variable_scope('model', reuse=reuse):
            y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

            # Discrete -> Continuous
            objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
            z = preprocess(x)
            z = z + tf.random_uniform(tf.shape(z), 0, 1./hps.n_bins)
            objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

            # Encode
            z = Z.squeeze2d(z, 2)  # > 16x16x12
            z, objective, _ = encoder(z, objective)

            # Prior
            hps.top_shape = Z.int_shape(z)[1:]
            logp, _, _ = prior("prior", y_onehot, hps)
            objective += logp(z)

            # Generative loss
            nobj = - objective
            bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int(
                x.get_shape()[2]) * int(x.get_shape()[3]))  # bits per subpixel

            # Predictive loss
            if hps.weight_y > 0 and hps.ycond:

                # Classification loss
                h_y = tf.reduce_mean(z, axis=[1, 2])
                y_logits = Z.linear_zeros("classifier", h_y, hps.n_y)
                bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=y_onehot, logits=y_logits) / np.log(2.)

                # Classification accuracy
                y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
                classification_error = 1 - \
                    tf.cast(tf.equal(y_predicted, y), tf.float32)
            else:
                bits_y = tf.zeros_like(bits_x)
                classification_error = tf.ones_like(bits_x)

        return bits_x, bits_y, classification_error

    def f_loss(iterator, is_training, reuse=False):
        if hps.direct_iterator and iterator is not None:
            x, y = iterator.get_next()
        else:
            x, y = X, Y

        bits_x, bits_y, pred_loss = _f_loss(x, y, is_training, reuse)
        local_loss = bits_x + hps.weight_y * bits_y
        stats = [local_loss, bits_x, bits_y, pred_loss]
        global_stats = Z.allreduce_mean(
            tf.stack([tf.reduce_mean(i) for i in stats]))

        return tf.reduce_mean(local_loss), global_stats

    feeds = {'x': X, 'y': Y}
    m = abstract_model_xy(sess, hps, feeds, train_iterator,
                          test_iterator, data_init, lr, f_loss)

    # === Sampling function
    def f_sample(y, eps_std):
        with tf.variable_scope('model', reuse=True):
            y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

            _, sample, _ = prior("prior", y_onehot, hps)
            z = sample(eps_std=eps_std)
            z = decoder(z, eps_std=eps_std)
            z = Z.unsqueeze2d(z, 2)  # 8x8x12 -> 16x16x3
            x = postprocess(z)

        return x

    m.eps_std = tf.placeholder(tf.float32, [None], name='eps_std')
    x_sampled = f_sample(Y, m.eps_std)

    def sample(_y, _eps_std):
        return m.sess.run(x_sampled, {Y: _y, m.eps_std: _eps_std})
    m.sample = sample

    if hps.inference:
        # === Encoder-Decoder functions
        def f_encode(x, y, reuse=True):
            with tf.variable_scope('model', reuse=reuse):
                y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

                # Discrete -> Continuous
                objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
                z = preprocess(x)
                z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
                objective += - np.log(hps.n_bins) * \
                    np.prod(Z.int_shape(z)[1:])

                # Encode
                z = Z.squeeze2d(z, 2)  # > 16x16x12
                z, objective, eps = encoder(z, objective)

                # Prior
                hps.top_shape = Z.int_shape(z)[1:]
                logp, _, _eps = prior("prior", y_onehot, hps)
                objective += logp(z)
                eps.append(_eps(z))

            return eps

        def f_decode(y, eps, reuse=True):
            with tf.variable_scope('model', reuse=reuse):
                y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

                _, sample, _ = prior("prior", y_onehot, hps)
                z = sample(eps=eps[-1])
                z = decoder(z, eps=eps[:-1])
                z = Z.unsqueeze2d(z, 2)  # 8x8x12 -> 16x16x3
                x = postprocess(z)

            return x

        enc_eps = f_encode(X, Y)
        dec_eps = []
        for i, _eps in enumerate(enc_eps):
            dec_eps.append(tf.placeholder(
                tf.float32, _eps.get_shape().as_list(),
                name="dec_eps_" + str(i)))
        dec_x = f_decode(Y, dec_eps)

        eps_shapes = [_eps.get_shape().as_list()[1:] for _eps in enc_eps]

        def flatten_eps(eps):
            # [BS, eps_size]
            return np.concatenate(
                [np.reshape(e, (e.shape[0], -1)) for e in eps], axis=-1)

        def unflatten_eps(feps):
            index = 0
            eps = []
            bs = feps.shape[0]
            for shape in eps_shapes:
                eps.append(np.reshape(
                    feps[:, index: index+np.prod(shape)], (bs, *shape)))
                index += np.prod(shape)
            return eps

        # If model is unconditional, always pass y = np.zeros([bs], dtype=np.int32)
        def encode(x, y):
            return flatten_eps(sess.run(enc_eps, {X: x, Y: y}))

        def decode(y, feps):
            eps = unflatten_eps(feps)
            feed_dict = {Y: y}
            for i in range(len(dec_eps)):
                feed_dict[dec_eps[i]] = eps[i]
            return sess.run(dec_x, feed_dict)

        m.encode = encode
        m.decode = decode

    return m
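

# Hedged sketch (not part of the original file): demonstrates that the
# flatten_eps / unflatten_eps pair above is an exact round trip. The
# per-level eps shapes are assumed for illustration (they match 32x32x3
# inputs with n_levels = 3); in the real model they come from enc_eps. Note
# the flattened size, 16*16*6 + 8*8*12 + 4*4*48 = 3072 = 32*32*3: the flow
# is bijective, so total eps dimensionality equals input dimensionality.
def _eps_roundtrip_check():
    shapes = [[16, 16, 6], [8, 8, 12], [4, 4, 48]]  # assumed eps_shapes
    rng = np.random.RandomState(0)
    bs = 2
    eps = [rng.randn(bs, *s).astype('float32') for s in shapes]
    feps = np.concatenate([np.reshape(e, (bs, -1)) for e in eps], axis=-1)
    index, rec = 0, []
    for shape in shapes:
        size = int(np.prod(shape))
        rec.append(np.reshape(feps[:, index: index+size], (bs, *shape)))
        index += size
    assert all(np.array_equal(a, b) for a, b in zip(eps, rec))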


def checkpoint(z, logdet):
    zshape = Z.int_shape(z)
    z = tf.reshape(z, [-1, zshape[1]*zshape[2]*zshape[3]])
    logdet = tf.reshape(logdet, [-1, 1])
    combined = tf.concat([z, logdet], axis=1)
    tf.add_to_collection('checkpoints', combined)
    logdet = combined[:, -1]
    z = tf.reshape(combined[:, :-1], [-1, zshape[1], zshape[2], zshape[3]])
    return z, logdet


@add_arg_scope
def revnet2d(name, z, logdet, hps, reverse=False):
    with tf.variable_scope(name):
        if not reverse:
            for i in range(hps.depth):
                z, logdet = checkpoint(z, logdet)
                z, logdet = revnet2d_step(str(i), z, logdet, hps, reverse)
            z, logdet = checkpoint(z, logdet)
        else:
            for i in reversed(range(hps.depth)):
                z, logdet = revnet2d_step(str(i), z, logdet, hps, reverse)
    return z, logdet


# Simpler, new version
@add_arg_scope
def revnet2d_step(name, z, logdet, hps, reverse):
    with tf.variable_scope(name):

        shape = Z.int_shape(z)
        n_z = shape[3]
        assert n_z % 2 == 0

        if not reverse:

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv("invconv", z, logdet)
            else:
                raise Exception()

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 += f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                z2 += shift
                z2 *= scale
                logdet += tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
            else:
                raise Exception()

            z = tf.concat([z1, z2], 3)

        else:

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 -= f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                z2 /= scale
                z2 -= shift
                logdet -= tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
            else:
                raise Exception()

            z = tf.concat([z1, z2], 3)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z, reverse=True)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z, reverse=True)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True)
            else:
                raise Exception()

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

    return z, logdet


def f(name, h, width, n_out=None):
    n_out = n_out or int(h.get_shape()[3])
    with tf.variable_scope(name):
        h = tf.nn.relu(Z.conv2d("l_1", h, width))
        h = tf.nn.relu(Z.conv2d("l_2", h, width, filter_size=[1, 1]))
        h = Z.conv2d_zeros("l_last", h, n_out)
    return h


def f_resnet(name, h, width, n_out=None):
    n_out = n_out or int(h.get_shape()[3])
    with tf.variable_scope(name):
        h = tf.nn.relu(Z.conv2d("l_1", h, width))
        h = Z.conv2d_zeros("l_2", h, n_out)
    return h
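

# Hedged sketch (not part of the original file): checks that the affine
# coupling in revnet2d_step is exactly invertible, using random NumPy arrays
# as stand-ins for z2 and for the shift/scale that f("f1", ...) would produce.
def _coupling_inverse_check():
    rng = np.random.RandomState(0)
    z2 = rng.randn(2, 4, 4, 4).astype('float32')
    shift = rng.randn(*z2.shape).astype('float32')
    # Same form as the TF code: scale = sigmoid(h + 2), so scale > 0 always.
    scale = 1. / (1. + np.exp(-(rng.randn(*z2.shape).astype('float32') + 2.)))
    y2 = (z2 + shift) * scale       # forward: z2 += shift; z2 *= scale
    z2_rec = y2 / scale - shift     # reverse: z2 /= scale; z2 -= shift
    assert np.allclose(z2, z2_rec, atol=1e-5)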


# Invertible 1x1 conv
@add_arg_scope
def invertible_1x1_conv(name, z, logdet, reverse=False):

    if True:  # Set to "False" to use the LU-decomposed version

        with tf.variable_scope(name):

            shape = Z.int_shape(z)
            w_shape = [shape[3], shape[3]]

            # Sample a random orthogonal matrix:
            w_init = np.linalg.qr(np.random.randn(
                *w_shape))[0].astype('float32')

            w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)

            # dlogdet = tf.linalg.LinearOperator(w).log_abs_determinant() * shape[1]*shape[2]
            dlogdet = tf.cast(tf.log(abs(tf.matrix_determinant(
                tf.cast(w, 'float64')))), 'float32') * shape[1]*shape[2]

            if not reverse:

                _w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += dlogdet

                return z, logdet
            else:

                _w = tf.matrix_inverse(w)
                _w = tf.reshape(_w, [1, 1]+w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet -= dlogdet

                return z, logdet

    else:

        # LU-decomposed version
        shape = Z.int_shape(z)
        with tf.variable_scope(name):

            dtype = 'float64'

            # Random orthogonal matrix:
            import scipy.linalg
            np_w = scipy.linalg.qr(np.random.randn(shape[3], shape[3]))[
                0].astype('float32')

            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(abs(np_s))
            np_u = np.triu(np_u, k=1)

            p = tf.get_variable("P", initializer=np_p, trainable=False)
            l = tf.get_variable("L", initializer=np_l)
            sign_s = tf.get_variable(
                "sign_S", initializer=np_sign_s, trainable=False)
            log_s = tf.get_variable("log_S", initializer=np_log_s)
            # S = tf.get_variable("S", initializer=np_s)
            u = tf.get_variable("U", initializer=np_u)

            p = tf.cast(p, dtype)
            l = tf.cast(l, dtype)
            sign_s = tf.cast(sign_s, dtype)
            log_s = tf.cast(log_s, dtype)
            u = tf.cast(u, dtype)

            w_shape = [shape[3], shape[3]]

            l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1)
            l = l * l_mask + tf.eye(*w_shape, dtype=dtype)
            u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
            w = tf.matmul(p, tf.matmul(l, u))

            if True:
                u_inv = tf.matrix_inverse(u)
                l_inv = tf.matrix_inverse(l)
                p_inv = tf.matrix_inverse(p)
                w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))
            else:
                w_inv = tf.matrix_inverse(w)

            w = tf.cast(w, tf.float32)
            w_inv = tf.cast(w_inv, tf.float32)
            log_s = tf.cast(log_s, tf.float32)

            if not reverse:

                w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += tf.reduce_sum(log_s) * (shape[1]*shape[2])

                return z, logdet
            else:

                w_inv = tf.reshape(w_inv, [1, 1]+w_shape)
                z = tf.nn.conv2d(
                    z, w_inv, [1, 1, 1, 1], 'SAME', data_format='NHWC')
                logdet -= tf.reduce_sum(log_s) * (shape[1]*shape[2])

                return z, logdet


@add_arg_scope
def split2d(name, z, objective=0.):
    with tf.variable_scope(name):
        n_z = Z.int_shape(z)[3]
        z1 = z[:, :, :, :n_z // 2]
        z2 = z[:, :, :, n_z // 2:]
        pz = split2d_prior(z1)
        objective += pz.logp(z2)
        z1 = Z.squeeze2d(z1)
        eps = pz.get_eps(z2)
        return z1, objective, eps


@add_arg_scope
def split2d_reverse(name, z, eps, eps_std):
    with tf.variable_scope(name):
        z1 = Z.unsqueeze2d(z)
        pz = split2d_prior(z1)
        if eps is not None:
            # Already sampled eps
            z2 = pz.sample2(eps)
        elif eps_std is not None:
            # Sample with given eps_std
            z2 = pz.sample2(pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1]))
        else:
            # Sample normally
            z2 = pz.sample
        z = tf.concat([z1, z2], 3)
        return z


@add_arg_scope
def split2d_prior(z):
    n_z2 = int(z.get_shape()[3])
    n_z1 = n_z2
    h = Z.conv2d_zeros("conv", z, 2 * n_z1)

    mean = h[:, :, :, 0::2]
    logs = h[:, :, :, 1::2]
    return Z.gaussian_diag(mean, logs)
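

# Hedged sketch (not part of the original file): verifies, in NumPy, the
# P*L*U parameterisation used by the LU-decomposed branch above, i.e. that
# masking L/U and rebuilding the diagonal as sign_s * exp(log_s) reconstructs
# W exactly, and that log|det W| = sum(log_s), the quantity added to logdet
# once per spatial position.
def _plu_reconstruction_check():
    import scipy.linalg
    rng = np.random.RandomState(0)
    c = 4
    np_w = scipy.linalg.qr(rng.randn(c, c))[0]
    np_p, np_l, np_u = scipy.linalg.lu(np_w)
    np_s = np.diag(np_u)
    sign_s, log_s = np.sign(np_s), np.log(abs(np_s))
    np_u = np.triu(np_u, k=1)
    l_mask = np.tril(np.ones((c, c)), -1)
    l = np_l * l_mask + np.eye(c)
    u = np_u * l_mask.T + np.diag(sign_s * np.exp(log_s))
    w = np_p @ l @ u
    assert np.allclose(w, np_w, atol=1e-6)
    assert np.isclose(np.log(abs(np.linalg.det(w))), log_s.sum(), atol=1e-6)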