in model.py [0:0]
def revnet2d_step(name, z, logdet, hps, reverse):
    """Build one flow step on `z`, or its inverse when `reverse` is truthy.

    Forward order:  actnorm -> channel permutation -> coupling.
    Reverse order:  inverse coupling -> inverse permutation -> inverse actnorm.
    All ops are created inside `tf.variable_scope(name)`.

    Args:
        name: Variable-scope name for this step.
        z: Tensor indexed with four axes; axis 3 is split into two halves
            for the coupling, so its static size must be even.
            (Presumably NHWC activations -- confirm against callers.)
        logdet: Running log-determinant term. It is threaded through
            `Z.actnorm` and `invertible_1x1_conv` (both return an updated
            value) and adjusted by the affine coupling.
        hps: Hyper-parameter object. Fields read here:
            flow_permutation -- 0: `Z.reverse_features`,
                                1: `Z.shuffle_features`,
                                2: `invertible_1x1_conv`.
            flow_coupling    -- 0: additive coupling, 1: affine coupling.
            width            -- forwarded to `f`.
        reverse: Falsy builds the forward step, truthy builds the inverse.

    Returns:
        Tuple `(z, logdet)` after the step.

    Raises:
        AssertionError: if the size of axis 3 is odd.
        ValueError: if `hps.flow_permutation` or `hps.flow_coupling` holds
            an unsupported value.
    """
    with tf.variable_scope(name):
        n_z = Z.int_shape(z)[3]
        assert n_z % 2 == 0, (
            "revnet2d_step needs an even size on axis 3, got {!r}".format(n_z))

        # Reject unsupported settings before any op is added to the graph,
        # and say which setting was wrong.
        permutation = hps.flow_permutation
        coupling = hps.flow_coupling
        if permutation not in (0, 1, 2):
            raise ValueError(
                "Unsupported hps.flow_permutation: {!r} (expected 0, 1 or 2)"
                .format(permutation))
        if coupling not in (0, 1):
            raise ValueError(
                "Unsupported hps.flow_coupling: {!r} (expected 0 or 1)"
                .format(coupling))

        half = n_z // 2

        if not reverse:
            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)

            if permutation == 0:
                z = Z.reverse_features("reverse", z)
            elif permutation == 1:
                z = Z.shuffle_features("shuffle", z)
            else:  # permutation == 2, validated above
                z, logdet = invertible_1x1_conv("invconv", z, logdet)

            # z1 passes through unchanged and conditions the update of z2.
            z1 = z[:, :, :, :half]
            z2 = z[:, :, :, half:]

            if coupling == 0:
                # Additive coupling: no logdet update is made here.
                z2 += f("f1", z1, hps.width)
            else:  # coupling == 1, validated above
                # NOTE(review): the 4th argument to `f` is presumably its
                # output channel count (n_z = shift + scale) -- confirm.
                h = f("f1", z1, hps.width, n_z)
                # Even channels of h are the shift, odd channels the
                # pre-activation of the scale.
                shift = h[:, :, :, 0::2]
                # sigmoid keeps scale inside (0, 1); with the +2 offset a
                # zero pre-activation gives sigmoid(2) ~= 0.88. The exp
                # alternative, tf.exp(h[:, :, :, 1::2]), is not used.
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                z2 += shift
                z2 *= scale
                logdet += tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])

            z = tf.concat([z1, z2], 3)

        else:
            z1 = z[:, :, :, :half]
            z2 = z[:, :, :, half:]

            if coupling == 0:
                z2 -= f("f1", z1, hps.width)
            else:  # coupling == 1, validated above
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                # Undo the forward affine map in the opposite order:
                # forward was (z2 + shift) * scale.
                z2 /= scale
                z2 -= shift
                logdet -= tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])

            z = tf.concat([z1, z2], 3)

            if permutation == 0:
                z = Z.reverse_features("reverse", z, reverse=True)
            elif permutation == 1:
                z = Z.shuffle_features("shuffle", z, reverse=True)
            else:  # permutation == 2, validated above
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True)

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

    return z, logdet