in model.py [0:0]
def invertible_1x1_conv(name, z, logdet, reverse=False, use_lu=False):
    """Invertible 1x1 convolution that mixes the channels of `z`.

    Every spatial position of `z` is multiplied by a learned square matrix
    W of shape [C, C] (C = shape[3] of `z`). The running log-determinant
    is updated by H * W * log|det W|.

    Args:
        name: variable scope under which the weight variables are created.
        z: input tensor in NHWC layout. Its static shape is read with
            `Z.int_shape` and H * W is multiplied directly, so H, W and C
            are assumed to be statically known (TODO confirm for callers).
        logdet: running log-determinant to update.
        reverse: if False, apply W and add the log-determinant term. If
            True, apply W^-1 and subtract the same term.
        use_lu: if False (default, the original behaviour), W is stored as
            a dense variable "W". If True, W is stored in LU-decomposed
            form (variables "P", "L", "sign_S", "log_S", "U"), which makes
            log|det W| a plain sum of "log_S".

    Returns:
        Tuple (z, logdet): the transformed tensor and the updated logdet.
    """
    if use_lu:
        return _invertible_1x1_conv_lu(name, z, logdet, reverse)
    return _invertible_1x1_conv_dense(name, z, logdet, reverse)


def _invertible_1x1_apply(z, kernel, w_shape):
    """Convolve NHWC `z` with the [C, C] matrix `kernel` used as a 1x1 filter."""
    kernel = tf.reshape(kernel, [1, 1] + w_shape)
    return tf.nn.conv2d(z, kernel, [1, 1, 1, 1], 'SAME', data_format='NHWC')


def _invertible_1x1_conv_dense(name, z, logdet, reverse):
    """Invertible 1x1 conv with W stored as a single dense [C, C] variable."""
    with tf.variable_scope(name):
        shape = Z.int_shape(z)
        w_shape = [shape[3], shape[3]]
        # Initialise W with a random orthogonal matrix (so |det W| = 1).
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype('float32')
        w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)
        # log|det W| is evaluated in float64 (presumably for precision/range
        # of the determinant). It is identical at every spatial position,
        # hence the H * W factor.
        dlogdet = tf.cast(tf.log(tf.abs(tf.matrix_determinant(
            tf.cast(w, 'float64')))), 'float32') * shape[1] * shape[2]
        if not reverse:
            z = _invertible_1x1_apply(z, w, w_shape)
            logdet += dlogdet
        else:
            z = _invertible_1x1_apply(z, tf.matrix_inverse(w), w_shape)
            logdet -= dlogdet
        return z, logdet


def _invertible_1x1_conv_lu(name, z, logdet, reverse):
    """Invertible 1x1 conv with W = P @ L @ (U + diag(sign_S * exp(log_S)))."""
    # A bare `import scipy` does not guarantee that the `scipy.linalg`
    # submodule is loaded (older SciPy versions do not import it lazily).
    import scipy.linalg

    shape = Z.int_shape(z)
    w_shape = [shape[3], shape[3]]
    with tf.variable_scope(name):
        dtype = 'float64'
        # Random orthogonal matrix, then its LU decomposition W = P L U.
        np_w = scipy.linalg.qr(np.random.randn(*w_shape))[0].astype('float32')
        np_p, np_l, np_u = scipy.linalg.lu(np_w)
        np_s = np.diag(np_u)
        np_sign_s = np.sign(np_s)
        np_log_s = np.log(np.abs(np_s))
        # Keep only the strictly upper part of U; its diagonal is
        # re-parameterised as sign_S * exp(log_S).
        np_u = np.triu(np_u, k=1)

        p = tf.get_variable("P", initializer=np_p, trainable=False)
        l = tf.get_variable("L", initializer=np_l)
        sign_s = tf.get_variable(
            "sign_S", initializer=np_sign_s, trainable=False)
        log_s = tf.get_variable("log_S", initializer=np_log_s)
        u = tf.get_variable("U", initializer=np_u)

        # Assemble the factors in float64.
        p, l, sign_s, log_s, u = [
            tf.cast(v, dtype) for v in (p, l, sign_s, log_s, u)]

        # Masking keeps L unit-lower-triangular and U strictly upper
        # triangular, whatever values training writes into the variables.
        l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1)
        l = l * l_mask + tf.eye(*w_shape, dtype=dtype)
        u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))

        # det P = +-1 and det L = 1, so log|det W| = sum(log_S). The term is
        # the same at every spatial position, hence the H * W factor.
        dlogdet = tf.reduce_sum(tf.cast(log_s, tf.float32)) * (
            shape[1] * shape[2])

        if not reverse:
            w = tf.cast(tf.matmul(p, tf.matmul(l, u)), tf.float32)
            z = _invertible_1x1_apply(z, w, w_shape)
            logdet += dlogdet
        else:
            # W^-1 = U^-1 L^-1 P^-1. P is a (non-trainable) permutation
            # matrix from scipy.linalg.lu, so P^-1 is exactly P^T.
            u_inv = tf.matrix_inverse(u)
            l_inv = tf.matrix_inverse(l)
            p_inv = tf.transpose(p)
            w_inv = tf.cast(
                tf.matmul(u_inv, tf.matmul(l_inv, p_inv)), tf.float32)
            z = _invertible_1x1_apply(z, w_inv, w_shape)
            logdet -= dlogdet
        return z, logdet