# tfops.py
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import add_arg_scope, arg_scope
from tensorflow.contrib.layers import variance_scaling_initializer
import numpy as np
import horovod.tensorflow as hvd
# Debugging function
do_print_act_stats = True
def print_act_stats(x, _str=""):
    if not do_print_act_stats:
        return x
    if hvd.rank() != 0:
        return x
    if len(x.get_shape()) == 1:
        x_mean, x_var = tf.nn.moments(x, [0], keep_dims=True)
    if len(x.get_shape()) == 2:
        x_mean, x_var = tf.nn.moments(x, [0], keep_dims=True)
    if len(x.get_shape()) == 4:
        x_mean, x_var = tf.nn.moments(x, [0, 1, 2], keep_dims=True)
    stats = [tf.reduce_min(x_mean), tf.reduce_mean(x_mean), tf.reduce_max(x_mean),
             tf.reduce_min(tf.sqrt(x_var)), tf.reduce_mean(tf.sqrt(x_var)), tf.reduce_max(tf.sqrt(x_var))]
    return tf.Print(x, stats, "["+_str+"] "+x.name)
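# Usage sketch (added): wrap an activation to log summary statistics
# (min/mean/max of the per-feature mean and std) on rank 0, e.g.
#   x = print_act_stats(x, "after conv1")
# The tensor itself is passed through unchanged.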
# Allreduce methods
def allreduce_sum(x):
    if hvd.size() == 1:
        return x
    return hvd.mpi_ops._allreduce(x)


def allreduce_mean(x):
    x = allreduce_sum(x) / hvd.size()
    return x
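# Illustrative sketch (added, not part of the original file): averaging a
# per-worker statistic across Horovod ranks, e.g. for data-dependent init
# statistics. Assumes hvd.init() has been called; the function name is
# hypothetical.
def _example_allreduce_mean():
    local_stat = tf.reduce_mean(tf.random_normal([16]))  # per-worker scalar
    return allreduce_mean(local_stat)                     # mean over all ranks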
def default_initial_value(shape, std=0.05):
    return tf.random_normal(shape, 0., std)


def default_initializer(std=0.05):
    return tf.random_normal_initializer(0., std)


def int_shape(x):
    if str(x.get_shape()[0]) != '?':
        return list(map(int, x.get_shape()))
    return [-1] + list(map(int, x.get_shape()[1:]))
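# Example (added): int_shape maps an unknown batch dimension to -1 so the
# result can be fed directly to tf.reshape, e.g.
#   x = tf.placeholder(tf.float32, [None, 32, 32, 3])
#   int_shape(x)  # -> [-1, 32, 32, 3]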
# wrapper tf.get_variable, augmented with 'init' functionality
# Get variable with data dependent init
@add_arg_scope
def get_variable_ddi(name, shape, initial_value, dtype=tf.float32, init=False, trainable=True):
    w = tf.get_variable(name, shape, dtype, None, trainable=trainable)
    if init:
        w = w.assign(initial_value)
        with tf.control_dependencies([w]):
            return w
    return w
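# Typical use (added, hedged): on the data-dependent initialization pass the
# caller supplies a statistic of the current batch as `initial_value` (e.g.
# actnorm_center below passes -x_mean). With init=True the assign runs before
# the variable is read; on later passes the stored value is used unchanged.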
# Activation normalization
# Convenience function that does centering+scaling
@add_arg_scope
def actnorm(name, x, scale=1., logdet=None, logscale_factor=3., batch_variance=False, reverse=False, init=False, trainable=True):
    # arg_scope is a context manager; use `with` so the trainable flag is
    # actually applied to get_variable_ddi.
    with arg_scope([get_variable_ddi], trainable=trainable):
        if not reverse:
            x = actnorm_center(name+"_center", x, reverse)
            x = actnorm_scale(name+"_scale", x, scale, logdet,
                              logscale_factor, batch_variance, reverse, init)
            if logdet is not None:
                x, logdet = x
        else:
            x = actnorm_scale(name + "_scale", x, scale, logdet,
                              logscale_factor, batch_variance, reverse, init)
            if logdet is not None:
                x, logdet = x
            x = actnorm_center(name+"_center", x, reverse)
        if logdet is not None:
            return x, logdet
        return x
# Activation normalization
@add_arg_scope
def actnorm_center(name, x, reverse=False):
    shape = x.get_shape()
    with tf.variable_scope(name):
        assert len(shape) == 2 or len(shape) == 4
        if len(shape) == 2:
            x_mean = tf.reduce_mean(x, [0], keepdims=True)
            b = get_variable_ddi(
                "b", (1, int_shape(x)[1]), initial_value=-x_mean)
        elif len(shape) == 4:
            x_mean = tf.reduce_mean(x, [0, 1, 2], keepdims=True)
            b = get_variable_ddi(
                "b", (1, 1, 1, int_shape(x)[3]), initial_value=-x_mean)
        if not reverse:
            x += b
        else:
            x -= b
        return x
# Activation normalization
@add_arg_scope
def actnorm_scale(name, x, scale=1., logdet=None, logscale_factor=3., batch_variance=False, reverse=False, init=False, trainable=True):
    shape = x.get_shape()
    with tf.variable_scope(name), arg_scope([get_variable_ddi], trainable=trainable):
        assert len(shape) == 2 or len(shape) == 4
        if len(shape) == 2:
            x_var = tf.reduce_mean(x**2, [0], keepdims=True)
            logdet_factor = 1
            _shape = (1, int_shape(x)[1])
        elif len(shape) == 4:
            x_var = tf.reduce_mean(x**2, [0, 1, 2], keepdims=True)
            logdet_factor = int(shape[1])*int(shape[2])
            _shape = (1, 1, 1, int_shape(x)[3])
        if batch_variance:
            x_var = tf.reduce_mean(x**2, keepdims=True)
        if init and False:
            # MPI all-reduce
            x_var = allreduce_mean(x_var)
            # Somehow this also slows down graph when not initializing
            # (it's not optimized away?)
        if True:
            logs = get_variable_ddi("logs", _shape, initial_value=tf.log(
                scale/(tf.sqrt(x_var)+1e-6))/logscale_factor)*logscale_factor
            if not reverse:
                x = x * tf.exp(logs)
            else:
                x = x * tf.exp(-logs)
        else:
            # Alternative; doesn't seem to do significantly worse or better
            # than the logarithmic version above.
            s = get_variable_ddi("s", _shape, initial_value=scale /
                                 (tf.sqrt(x_var) + 1e-6) / logscale_factor)*logscale_factor
            logs = tf.log(tf.abs(s))
            if not reverse:
                x *= s
            else:
                x /= s
        if logdet is not None:
            dlogdet = tf.reduce_sum(logs) * logdet_factor
            if reverse:
                dlogdet *= -1
            return x, logdet + dlogdet
        return x
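# Illustrative round-trip (added, not part of the original Glow code): actnorm
# applied forward and then in reverse should reconstruct its input, with the
# per-example log-determinant accumulated on the forward pass. The scope and
# function names are hypothetical; AUTO_REUSE lets both passes share the same
# variables.
def _example_actnorm_roundtrip():
    x = tf.random_normal([16, 8, 8, 4])
    with tf.variable_scope("demo_actnorm", reuse=tf.AUTO_REUSE):
        y, logdet = actnorm("actnorm", x, logdet=tf.zeros([16]))
        x_rec = actnorm("actnorm", y, reverse=True)
    return y, logdet, x_rec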
# Linear layer with activation normalization
@add_arg_scope
def linear(name, x, width, do_weightnorm=True, do_actnorm=True, initializer=None, scale=1.):
    initializer = initializer or default_initializer()
    with tf.variable_scope(name):
        n_in = int(x.get_shape()[1])
        w = tf.get_variable("W", [n_in, width],
                            tf.float32, initializer=initializer)
        if do_weightnorm:
            w = tf.nn.l2_normalize(w, [0])
        x = tf.matmul(x, w)
        x += tf.get_variable("b", [1, width],
                             initializer=tf.zeros_initializer())
        if do_actnorm:
            x = actnorm("actnorm", x, scale)
        return x
# Linear layer with zero init
@add_arg_scope
def linear_zeros(name, x, width, logscale_factor=3):
    with tf.variable_scope(name):
        n_in = int(x.get_shape()[1])
        w = tf.get_variable("W", [n_in, width], tf.float32,
                            initializer=tf.zeros_initializer())
        x = tf.matmul(x, w)
        x += tf.get_variable("b", [1, width],
                             initializer=tf.zeros_initializer())
        x *= tf.exp(tf.get_variable("logs",
                                    [1, width], initializer=tf.zeros_initializer()) * logscale_factor)
        return x
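# Note (added): because W, b and logs all start at zero, linear_zeros (and
# conv2d_zeros below) output exactly zero at initialization, so layers built
# on top of them start out as identity-like transformations.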
# Slow way to add edge padding
def add_edge_padding(x, filter_size):
    assert filter_size[0] % 2 == 1
    if filter_size[0] == 1 and filter_size[1] == 1:
        return x
    a = (filter_size[0] - 1) // 2  # vertical padding size
    b = (filter_size[1] - 1) // 2  # horizontal padding size
    if True:
        x = tf.pad(x, [[0, 0], [a, a], [b, b], [0, 0]])
        name = "_".join([str(dim) for dim in [a, b, *int_shape(x)[1:3]]])
        pads = tf.get_collection(name)
        if not pads:
            if hvd.rank() == 0:
                print("Creating pad", name)
            pad = np.zeros([1] + int_shape(x)[1:3] + [1], dtype='float32')
            pad[:, :a, :, 0] = 1.
            pad[:, -a:, :, 0] = 1.
            pad[:, :, :b, 0] = 1.
            pad[:, :, -b:, 0] = 1.
            pad = tf.convert_to_tensor(pad)
            tf.add_to_collection(name, pad)
        else:
            pad = pads[0]
        pad = tf.tile(pad, [tf.shape(x)[0], 1, 1, 1])
        x = tf.concat([x, pad], axis=3)
    else:
        pad = tf.pad(tf.zeros_like(x[:, :, :, :1]) - 1,
                     [[0, 0], [a, a], [b, b], [0, 0]]) + 1
        x = tf.pad(x, [[0, 0], [a, a], [b, b], [0, 0]])
        x = tf.concat([x, pad], axis=3)
    return x
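# Shape sketch (added): for a 3x3 filter, a = b = 1, so an input of shape
# [N, H, W, C] is zero-padded to [N, H+2, W+2, C] and one indicator channel
# (1 on padded pixels, 0 elsewhere) is appended, giving [N, H+2, W+2, C+1].
# conv2d/conv2d_zeros below then convolve with 'VALID' padding.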
@add_arg_scope
def conv2d(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", do_weightnorm=False, do_actnorm=True, context1d=None, skip=1, edge_bias=True):
    with tf.variable_scope(name):
        if edge_bias and pad == "SAME":
            x = add_edge_padding(x, filter_size)
            pad = 'VALID'
        n_in = int(x.get_shape()[3])
        stride_shape = [1] + stride + [1]
        filter_shape = filter_size + [n_in, width]
        w = tf.get_variable("W", filter_shape, tf.float32,
                            initializer=default_initializer())
        if do_weightnorm:
            w = tf.nn.l2_normalize(w, [0, 1, 2])
        if skip == 1:
            x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC')
        else:
            assert stride[0] == 1 and stride[1] == 1
            x = tf.nn.atrous_conv2d(x, w, skip, pad)
        if do_actnorm:
            x = actnorm("actnorm", x)
        else:
            x += tf.get_variable("b", [1, 1, 1, width],
                                 initializer=tf.zeros_initializer())
        if context1d is not None:
            x += tf.reshape(linear("context", context1d,
                                   width), [-1, 1, 1, width])
        return x
@add_arg_scope
def separable_conv2d(name, x, width, filter_size=[3, 3], stride=[1, 1], padding="SAME", do_actnorm=True, std=0.05):
    n_in = int(x.get_shape()[3])
    with tf.variable_scope(name):
        assert filter_size[0] % 2 == 1 and filter_size[1] % 2 == 1
        strides = [1] + stride + [1]
        w1_shape = filter_size + [n_in, 1]
        w1_init = np.zeros(w1_shape, dtype='float32')
        w1_init[(filter_size[0]-1)//2, (filter_size[1]-1)//2, :,
                :] = 1.  # initialize depthwise conv as identity
        w1 = tf.get_variable("W1", dtype=tf.float32, initializer=w1_init)
        w2_shape = [1, 1, n_in, width]
        w2 = tf.get_variable("W2", w2_shape, tf.float32,
                             initializer=default_initializer(std))
        x = tf.nn.separable_conv2d(
            x, w1, w2, strides, padding, data_format='NHWC')
        if do_actnorm:
            x = actnorm("actnorm", x)
        else:
            # tf.zeros_initializer takes no std argument; the bias is simply
            # initialized to zero.
            x += tf.get_variable("b", [1, 1, 1, width],
                                 initializer=tf.zeros_initializer())
        return x
@add_arg_scope
def conv2d_zeros(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", logscale_factor=3, skip=1, edge_bias=True):
    with tf.variable_scope(name):
        if edge_bias and pad == "SAME":
            x = add_edge_padding(x, filter_size)
            pad = 'VALID'
        n_in = int(x.get_shape()[3])
        stride_shape = [1] + stride + [1]
        filter_shape = filter_size + [n_in, width]
        w = tf.get_variable("W", filter_shape, tf.float32,
                            initializer=tf.zeros_initializer())
        if skip == 1:
            x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC')
        else:
            assert stride[0] == 1 and stride[1] == 1
            x = tf.nn.atrous_conv2d(x, w, skip, pad)
        x += tf.get_variable("b", [1, 1, 1, width],
                             initializer=tf.zeros_initializer())
        x *= tf.exp(tf.get_variable("logs",
                                    [1, width], initializer=tf.zeros_initializer()) * logscale_factor)
        return x
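# Illustrative sketch (added, hypothetical names): a Glow-style coupling
# network is typically a stack of conv2d layers finished by conv2d_zeros, so
# the whole block outputs zeros at initialization. Assumes hvd.init() has
# been called (conv2d's edge padding checks hvd.rank()).
def _example_coupling_net(x, width=512, n_out=6):
    h = tf.nn.relu(conv2d("l_1", x, width))
    h = tf.nn.relu(conv2d("l_2", h, width, filter_size=[1, 1]))
    return conv2d_zeros("l_last", h, n_out)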
# 2X nearest-neighbour upsampling, also inspired by Jascha Sohl-Dickstein's code
def upsample2d_nearest_neighbour(x):
    shape = x.get_shape()
    n_batch = int(shape[0])
    height = int(shape[1])
    width = int(shape[2])
    n_channels = int(shape[3])
    x = tf.reshape(x, (n_batch, height, 1, width, 1, n_channels))
    # tf.concat takes the values first and the axis second (TF >= 1.0)
    x = tf.concat([x, x], 2)
    x = tf.concat([x, x], 4)
    x = tf.reshape(x, (n_batch, height*2, width*2, n_channels))
    return x
def upsample(x, factor=2):
    shape = x.get_shape()
    height = int(shape[1])
    width = int(shape[2])
    x = tf.image.resize_nearest_neighbor(x, [height * factor, width * factor])
    return x
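# Note (added): upsample(x, 2) and upsample2d_nearest_neighbour(x) produce the
# same 2x nearest-neighbour upsampling; the resize-based version also works
# when the batch size is not known statically.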
def squeeze2d(x, factor=2):
    assert factor >= 1
    if factor == 1:
        return x
    shape = x.get_shape()
    height = int(shape[1])
    width = int(shape[2])
    n_channels = int(shape[3])
    assert height % factor == 0 and width % factor == 0
    x = tf.reshape(x, [-1, height//factor, factor,
                       width//factor, factor, n_channels])
    x = tf.transpose(x, [0, 1, 3, 5, 2, 4])
    x = tf.reshape(x, [-1, height//factor, width //
                       factor, n_channels*factor*factor])
    return x
def unsqueeze2d(x, factor=2):
    assert factor >= 1
    if factor == 1:
        return x
    shape = x.get_shape()
    height = int(shape[1])
    width = int(shape[2])
    n_channels = int(shape[3])
    assert n_channels >= 4 and n_channels % 4 == 0
    x = tf.reshape(
        x, (-1, height, width, int(n_channels/factor**2), factor, factor))
    x = tf.transpose(x, [0, 1, 4, 2, 5, 3])
    x = tf.reshape(x, (-1, int(height*factor),
                       int(width*factor), int(n_channels/factor**2)))
    return x
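# Illustrative sketch (added, hypothetical function name): squeeze2d trades
# spatial resolution for channels and unsqueeze2d inverts it exactly.
def _example_squeeze_roundtrip():
    x = tf.random_normal([4, 8, 8, 3])
    y = squeeze2d(x, 2)        # static shape [4, 4, 4, 12]
    x_rec = unsqueeze2d(y, 2)  # static shape [4, 8, 8, 3]
    return y, x_rec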
# Reverse features across channel dimension
def reverse_features(name, h, reverse=False):
    # Channel reversal is its own inverse, so `name` and `reverse` are unused.
    return h[:, :, :, ::-1]
# Shuffle across the channel dimension
def shuffle_features(name, h, indices=None, return_indices=False, reverse=False):
    with tf.variable_scope(name):
        rng = np.random.RandomState(
            (abs(hash(tf.get_variable_scope().name))) % 10000000)

        if indices is None:
            # Create numpy and tensorflow variables with indices
            n_channels = int(h.get_shape()[-1])
            indices = list(range(n_channels))
            rng.shuffle(indices)
            # Reverse it
            indices_inverse = [0]*n_channels
            for i in range(n_channels):
                indices_inverse[indices[i]] = i

        tf_indices = tf.get_variable("indices", dtype=tf.int32, initializer=np.asarray(
            indices, dtype='int32'), trainable=False)
        tf_indices_reverse = tf.get_variable("indices_inverse", dtype=tf.int32, initializer=np.asarray(
            indices_inverse, dtype='int32'), trainable=False)

        _indices = tf_indices
        if reverse:
            _indices = tf_indices_reverse

        if len(h.get_shape()) == 2:
            # Slice
            h = tf.transpose(h)
            h = tf.gather(h, _indices)
            h = tf.transpose(h)
        elif len(h.get_shape()) == 4:
            # Slice
            h = tf.transpose(h, [3, 1, 2, 0])
            h = tf.gather(h, _indices)
            h = tf.transpose(h, [3, 1, 2, 0])
        if return_indices:
            return h, indices
        return h
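# Illustrative round-trip (added, not part of the original file): applying the
# same shuffle forward and then with reverse=True restores the channel order.
# AUTO_REUSE lets both calls share the stored index variables; the scope and
# function names are hypothetical.
def _example_shuffle_roundtrip():
    h = tf.random_normal([4, 8, 8, 6])
    with tf.variable_scope("demo_shuffle", reuse=tf.AUTO_REUSE):
        h_shuffled = shuffle_features("shuffle", h)
        h_restored = shuffle_features("shuffle", h_shuffled, reverse=True)
    return h_shuffled, h_restored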
def embedding(name, y, n_y, width):
    with tf.variable_scope(name):
        params = tf.get_variable(
            "embedding", [n_y, width], initializer=default_initializer())
        embeddings = tf.gather(params, y)
        return embeddings
# Random variables
def flatten_sum(logps):
    if len(logps.get_shape()) == 2:
        return tf.reduce_sum(logps, [1])
    elif len(logps.get_shape()) == 4:
        return tf.reduce_sum(logps, [1, 2, 3])
    else:
        raise ValueError("flatten_sum expects a 2D or 4D tensor")
def standard_gaussian(shape):
    return gaussian_diag(tf.zeros(shape), tf.zeros(shape))


def gaussian_diag(mean, logsd):
    class o(object):
        pass
    o.mean = mean
    o.logsd = logsd
    o.eps = tf.random_normal(tf.shape(mean))
    o.sample = mean + tf.exp(logsd) * o.eps
    o.sample2 = lambda eps: mean + tf.exp(logsd) * eps
    o.logps = lambda x: -0.5 * \
        (np.log(2 * np.pi) + 2. * logsd + (x - mean) ** 2 / tf.exp(2. * logsd))
    o.logp = lambda x: flatten_sum(o.logps(x))
    o.get_eps = lambda x: (x - mean) / tf.exp(logsd)
    return o
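# Illustrative sketch (added, hypothetical function name): sampling from a
# diagonal Gaussian and scoring the sample under the same distribution.
def _example_gaussian_diag():
    mean = tf.zeros([4, 8, 8, 3])
    logsd = tf.zeros([4, 8, 8, 3])
    pz = gaussian_diag(mean, logsd)
    z = pz.sample        # mean + exp(logsd) * eps
    logpz = pz.logp(z)   # per-example log-density, shape [4]
    return z, logpz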
# def discretized_logistic_old(mean, logscale, binsize=1 / 256.0, sample=None):
# scale = tf.exp(logscale)
# sample = (tf.floor(sample / binsize) * binsize - mean) / scale
# logp = tf.log(tf.sigmoid(sample + binsize / scale) - tf.sigmoid(sample) + 1e-7)
# return tf.reduce_sum(logp, [1, 2, 3])
def discretized_logistic(mean, logscale, binsize=1. / 256):
    class o(object):
        pass
    o.mean = mean
    o.logscale = logscale
    scale = tf.exp(logscale)

    def logps(x):
        x = (x - mean) / scale
        return tf.log(tf.sigmoid(x + binsize / scale) - tf.sigmoid(x) + 1e-7)
    o.logps = logps
    o.logp = lambda x: flatten_sum(logps(x))
    return o
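# Illustrative sketch (added, hypothetical function name): per-image
# log-likelihood under the discretized logistic with bins of width 1/256
# (e.g. 8-bit images rescaled to unit range).
def _example_discretized_logistic():
    mean = tf.zeros([4, 32, 32, 3])
    logscale = tf.zeros([4, 32, 32, 3])
    px = discretized_logistic(mean, logscale)
    x = tf.zeros([4, 32, 32, 3])
    return px.logp(x)    # shape [4]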
def _symmetric_matrix_square_root(mat, eps=1e-10):
    """Compute square root of a symmetric matrix.

    Note that this is different from an elementwise square root. We want to
    compute M' where M' = sqrt(mat) such that M' * M' = mat.

    Also note that this method **only** works for symmetric matrices.

    Args:
      mat: Matrix to take the square root of.
      eps: Small epsilon such that any element less than eps will not be square
        rooted to guard against numerical instability.

    Returns:
      Matrix square root of mat.
    """
    # Unlike numpy, tensorflow's return order is (s, u, v)
    s, u, v = tf.svd(mat)
    # sqrt is unstable around 0; singular values below eps are left
    # un-square-rooted (they are effectively zero anyway)
    si = tf.where(tf.less(s, eps), s, tf.sqrt(s))
    # Note that the v returned by TensorFlow is V
    # (when referencing the equation A = U S V^T),
    # unlike NumPy, which returns v = V^T
    return tf.matmul(
        tf.matmul(u, tf.diag(si)), v, transpose_b=True)
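# Illustrative check (added, hypothetical function name): for a symmetric PSD
# matrix M, the returned square root S satisfies S @ S ~= M.
def _example_matrix_sqrt():
    a = tf.random_normal([5, 5])
    m = tf.matmul(a, a, transpose_b=True)   # symmetric positive semi-definite
    s = _symmetric_matrix_square_root(m)
    return tf.matmul(s, s) - m              # approximately zero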