in tfops.py [0:0]
def actnorm_scale(name, x, scale=1., logdet=None, logscale_factor=3.,
                  batch_variance=False, reverse=False, init=False, trainable=True):
    """Scale half of actnorm: rescales x per channel by exp(logs) (or its
    inverse) and, if logdet is given, adds the corresponding log-determinant."""
    shape = x.get_shape()
    with tf.variable_scope(name), arg_scope([get_variable_ddi], trainable=trainable):
        assert len(shape) == 2 or len(shape) == 4
        if len(shape) == 2:
            # Dense input [B, C]: per-feature variance; each feature counts once in the log-det.
            x_var = tf.reduce_mean(x**2, [0], keepdims=True)
            logdet_factor = 1
            _shape = (1, int_shape(x)[1])
        elif len(shape) == 4:
            # Convolutional input [B, H, W, C]: per-channel variance;
            # the log-det is multiplied by the number of spatial positions H*W.
            x_var = tf.reduce_mean(x**2, [0, 1, 2], keepdims=True)
            logdet_factor = int(shape[1])*int(shape[2])
            _shape = (1, 1, 1, int_shape(x)[3])

        if batch_variance:
            # Use a single variance estimate over the whole batch instead of per channel.
            x_var = tf.reduce_mean(x**2, keepdims=True)

        if init and False:
            # MPI all-reduce
            x_var = allreduce_mean(x_var)
            # Somehow this also slows down graph when not initializing
            # (it's not optimized away?)

        if True:
            # Parameterize the scale by its logarithm, initialized so that the
            # rescaled activations have roughly unit variance.
            logs = get_variable_ddi("logs", _shape, initial_value=tf.log(
                scale/(tf.sqrt(x_var)+1e-6))/logscale_factor)*logscale_factor
            if not reverse:
                x = x * tf.exp(logs)
            else:
                x = x * tf.exp(-logs)
        else:
            # Alternative, doesn't seem to do significantly worse or better
            # than the logarithmic version above
            s = get_variable_ddi("s", _shape, initial_value=scale /
                                 (tf.sqrt(x_var) + 1e-6) / logscale_factor)*logscale_factor
            logs = tf.log(tf.abs(s))
            if not reverse:
                x *= s
            else:
                x /= s

        if logdet is not None:
            dlogdet = tf.reduce_sum(logs) * logdet_factor
            if reverse:
                dlogdet *= -1
            return x, logdet + dlogdet

        return x
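
A minimal usage sketch (not part of tfops.py): it assumes TensorFlow 1.x graph mode and that the module's other helpers (get_variable_ddi, int_shape, arg_scope) are in scope; the input shape and scope name are hypothetical. It illustrates that for a [B, H, W, C] input the forward pass rescales each channel by exp(logs) and adds H*W*sum(logs) to the per-example log-determinant.

import tensorflow as tf  # TensorFlow 1.x assumed, as in the rest of tfops.py

# Hypothetical NHWC activations and a per-example log-determinant accumulator.
x = tf.placeholder(tf.float32, [None, 16, 16, 48])
logdet = tf.zeros([tf.shape(x)[0]])

# Forward pass: each channel is scaled by exp(logs) and the objective picks up
# logdet_factor * sum(logs) = 16*16 * (sum over 48 channels) for every example.
z, logdet = actnorm_scale("actnorm_scale", x, logdet=logdet)

Calling the same scope again with reverse=True (with variable reuse handled by the caller, e.g. via tf.make_template) divides by the same scales and subtracts the same log-determinant term, which is what makes the operation invertible.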