in tensor2tensor/layers/discretization.py [0:0]
def discrete_bottleneck(inputs,
hidden_size,
z_size,
filter_size,
mode=None,
bottleneck_kind="dvq",
num_blocks=2,
num_residuals=1,
reshape_method="slice",
projection_tensors=None,
beta=0.25,
ema=True,
means=None,
ema_count=None,
ema_means=None,
epsilon=1e-5,
decay=0.999,
random_top_k=1,
soft_em=False,
num_samples=1,
softmax_k=0,
temperature_warmup_steps=150000,
do_hard_gumbel_softmax=False,
num_flows=0,
approximate_gs_entropy=False,
sum_over_latents=False,
discrete_mix=0.5,
noise_dev=1.,
startup_steps=50000,
summary=True,
name=None,
cond=True):
"""Discretization bottleneck.
Args:
inputs: Input to the bottleneck, a Tensor of shape [..., channels].
hidden_size: Dimension of the dense output.
z_size: Number of bits, where discrete codes range over [0, 2**z_size).
filter_size: Filter size in the embedding function.
mode: tf.estimator.ModeKeys.
bottleneck_kind: Kind of discretization bottleneck. One of dense, dvq
(decomposed vector quantization), gumbel-softmax, gumbel-softmax-dvq,
semhash, or vae.
num_blocks: Number of blocks. Used only if bottleneck_kind is DVQ.
num_residuals: Number of residual units used to compute nearest
neighbors. Used only if bottleneck_kind is DVQ.
reshape_method: Method to reshape. Used only if bottleneck_kind is DVQ.
projection_tensors: If the reshape method is project, then these are the
tensors used to project.
beta: Scale factor for the commitment (encoder) loss. Used only if
bottleneck_kind is DVQ.
ema: Whether to update embeddings using exponential moving averages. Used
only if bottleneck_kind is DVQ.
means: The embedding table. Used only if ema is True.
ema_count: Table of EMA counts, one per embedding, tracking how many inputs
in a batch were assigned to that embedding. Used only if ema is True.
ema_means: Exponentially averaged version of the embeddings. Used only if
ema is True.
epsilon: Small value to avoid dividing by zero in EMA update. Used only if
ema is True.
decay: Decay factor for the exponential moving average. Used only if ema is
True.
random_top_k: If > 1, sample the assigned code uniformly from the top-k
nearest codes instead of always taking the nearest one (noisy top-k). Used
only if bottleneck_kind is DVQ.
soft_em: Whether to use soft EM or hard EM. Used only if bottleneck_kind is
DVQ.
num_samples: Number of samples for soft EM. Used only if soft_em is True.
softmax_k: If > 0 then do top-k softmax. Used only if bottleneck_kind
is gumbel-softmax.
temperature_warmup_steps: Number of steps it takes to decay temperature to
0. Used only if bottleneck_kind is gumbel-softmax or gumbel-softmax-dvq.
do_hard_gumbel_softmax: Whether to use hard or soft Gumbel-Softmax
samples. Used only if bottleneck_kind is gumbel-softmax-dvq.
num_flows: Number of inverse autoregressive flows. Used only if
bottleneck_kind is gumbel-softmax-dvq.
approximate_gs_entropy: Whether to approximate the Gumbel-Softmax density
as a categorical distribution when calculating the sample entropy. Used
only if bottleneck_kind is gumbel-softmax-dvq.
sum_over_latents: Whether to sum over all non-batch dimensions before
taking mean of entropy loss term. Used only if bottleneck kind is DVQ
or gumbel-softmax-dvq.
discrete_mix: Factor for mixing discrete and non-discrete input. Used only
if bottleneck_kind is semhash.
noise_dev: Noise stddev. Used only if bottleneck_kind is semhash.
startup_steps: Number of warm-up steps; during training, the probability of
using the discretized code ramps up over roughly 2 * startup_steps. Used
only if bottleneck_kind is semhash.
summary: Whether to write summaries.
name: Name for the bottleneck scope.
cond: A tf.bool condition on whether to update the codebook.
Returns:
outputs_dense: Tensor of shape [..., output_dim]. The output dimension is
hidden_size if bottleneck_kind is dvq, gumbel-softmax, gumbel-softmax-dvq, or
semhash; filter_size if bottleneck_kind is dense or vae. If bottleneck_kind is
DVQ, outputs_dense represents the codebook (means) indexed by outputs_discrete.
outputs_discrete: Tensor of shape [...]. Discrete codes, each an index in
[0, 2**z_size). If soft_em is True, this is instead the soft one-hot
assignment over the block codebook.
extra_loss: Scalar Tensor. Sum of codebook and commitment losses if
bottleneck_kind is DVQ; else zero.
embed_fn: Function embed with arguments partially filled in.
neg_q_entropy: Scalar Tensor representing negative entropy of variational
approximation (0 if it is deterministic).
Raises:
ValueError: If projection_tensors is None when reshape_method is project, if
ema_count or ema_means is None when ema is True, or if bottleneck_kind or
reshape_method is unknown.
"""
if bottleneck_kind in ["dvq", "gumbel-softmax-dvq"]:
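# The DVQ family relies on externally created codebook variables: means must
# always be provided, and when ema is True the ema_count and ema_means slot
# variables are required as well.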
assert means is not None
if hidden_size % num_blocks != 0:
raise ValueError("num_blocks does not divide hidden size")
if z_size % num_residuals != 0:
raise ValueError("num_residuals does not divide z_size")
z_size_per_residual = int(z_size / num_residuals)
if z_size_per_residual % num_blocks != 0:
raise ValueError("num_blocks does not divide z_size per residual")
block_v_size = 2**int(z_size_per_residual / num_blocks)
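# Illustrative sizing (values are assumptions, not from the source): with
# z_size=14, num_residuals=1 and num_blocks=2, each block encodes 7 bits,
# so block_v_size = 2**7 = 128 codes per block.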
if ema:
if ema_count is None:
raise ValueError("ema_count is None but ema is True")
if ema_means is None:
raise ValueError("ema_means is None but ema is True")
else:
block_v_size = None
with tf.variable_scope(
name, default_name="discrete_bottleneck", reuse=tf.AUTO_REUSE):
embed_fn = partial(
embed,
hidden_size=hidden_size,
z_size=z_size,
filter_size=filter_size,
bottleneck_kind=bottleneck_kind,
soft_em=soft_em,
num_blocks=num_blocks,
num_residuals=num_residuals,
block_v_size=block_v_size,
means=means,
name=name)
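# embed_fn is returned so callers can later map discrete codes back to dense
# vectors (e.g. at decode time) without re-passing the bottleneck
# hyperparameters.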
if bottleneck_kind == "dense":
# Note discrete output is continuous here.
outputs_discrete = tf.layers.dense(inputs, z_size, name="vcc")
outputs_dense = tf.layers.dense(
outputs_discrete, filter_size, name="vch1")
extra_loss = tf.constant(0.0)
neg_q_entropy = tf.constant(0.0)
elif bottleneck_kind in ["dvq", "gumbel-softmax-dvq"]:
inputs_3d = inputs
if len(inputs.shape) == 4:
inputs_3d = tf.squeeze(inputs, axis=2)
if reshape_method == "slice":
x_reshaped = slice_hidden(
inputs_3d, hidden_size=hidden_size, num_blocks=num_blocks)
elif reshape_method == "project":
if projection_tensors is None:
raise ValueError(
"Projection tensors is None for reshape_method project")
x_reshaped = project_hidden(
inputs_3d,
projection_tensors=projection_tensors,
hidden_size=hidden_size,
num_blocks=num_blocks)
else:
raise ValueError("Unknown reshape_method")
x_res = tf.reshape(x_reshaped,
[-1] + common_layers.shape_list(x_reshaped)[2:])
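# Residual DVQ: each of the num_residuals stages quantizes whatever residual
# the previous stages left behind; x_res tracks that residual and x_means
# accumulates the quantized approximation.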
x_means_hot = []
x_means = 0
extra_loss = 0
for i in range(num_residuals):
x_means_hot_res, x_means_res, q_loss_res, e_loss_res, neg_q_entropy = (
embedding_lookup(
x_reshaped,
means=means[i],
num_blocks=num_blocks,
block_v_size=block_v_size,
bottleneck_kind=bottleneck_kind,
random_top_k=random_top_k,
soft_em=soft_em,
num_samples=num_samples,
temperature_warmup_steps=temperature_warmup_steps,
do_hard_gumbel_softmax=do_hard_gumbel_softmax,
num_flows=num_flows,
approximate_gs_entropy=approximate_gs_entropy,
sum_over_latents=sum_over_latents))
# Update the EMA variables.
if ema:
tf.logging.info("Using EMA with beta = {}".format(beta))
updated_ema_count_res = moving_averages.assign_moving_average(
ema_count[i],
tf.where(cond,
tf.reduce_sum(
tf.reshape(x_means_hot_res,
shape=[-1, num_blocks, block_v_size]),
axis=0), ema_count[i]),
decay,
zero_debias=False)
dw = tf.matmul(
tf.transpose(x_means_hot_res, perm=[1, 2, 0]),
tf.transpose(x_res, perm=[1, 0, 2]))
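# dw[b, v, :] is, for block b, the assignment-weighted sum of the inputs
# mapped to codebook entry v in this batch; its EMA, divided by the EMA
# counts, gives the updated cluster means.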
updated_ema_means_res = moving_averages.assign_moving_average(
ema_means[i], tf.where(cond, dw, ema_means[i]),
decay, zero_debias=False)
n = tf.reduce_sum(updated_ema_count_res, axis=-1, keep_dims=True)
updated_ema_count_res = (
(updated_ema_count_res + epsilon) / (n + 2**z_size * epsilon) * n)
# pylint: disable=g-no-augmented-assignment
updated_ema_means_res = updated_ema_means_res / tf.expand_dims(
updated_ema_count_res, axis=-1)
# pylint: enable=g-no-augmented-assignment
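# The two steps above Laplace-smooth the EMA cluster counts and then divide
# the EMA code sums by those counts, following the standard EMA codebook
# update for vector quantization (as in VQ-VAE).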
with tf.control_dependencies([e_loss_res]):
update_means_res = tf.assign(means[i],
tf.where(cond,
updated_ema_means_res,
means[i]))
with tf.control_dependencies([update_means_res]):
extra_loss += beta * e_loss_res
else:
extra_loss += q_loss_res + beta * e_loss_res
# Update the residuals.
x_res -= x_means_res
x_means += x_means_res
x_means_hot.append(x_means_hot_res)
# Get the discrete latent representation.
x_means_hot = tf.stack(x_means_hot, axis=1)
x_means_idx = tf.argmax(x_means_hot, axis=-1)
# Get the binary representation.
x_means_bits = int_to_bit(
x_means_idx,
num_bits=int(z_size / (num_residuals * num_blocks)),
base=2)
shape = common_layers.shape_list(x_means_bits)
new_shape = shape[:-2]
new_shape[-1] = z_size
x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
outputs_discrete = bit_to_int(
tf.to_int32(x_means_bits), num_bits=z_size, base=2)
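# Each residual/block index is written out in binary, the bits are
# concatenated across residuals and blocks, and the result is re-read as a
# single z_size-bit integer code.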
# Adjust shape of discrete outputs.
inputs_shape = common_layers.shape_list(inputs)
outputs_discrete = tf.reshape(outputs_discrete, inputs_shape[:-1])
# If we're using soft EM then set discretes to the hot representation.
if soft_em:
outputs_discrete = x_means_hot
outputs_discrete = tf.reshape(outputs_discrete,
inputs_shape[:-1] + [block_v_size])
# Reshape assuming hidden_size == inputs_shape[-1].
x_means = tf.reshape(x_means, inputs_shape)
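# Straight-through estimator: the forward pass emits the quantized means
# while gradients flow from outputs_dense directly to inputs.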
outputs_dense = inputs + tf.stop_gradient(x_means - inputs)
elif bottleneck_kind == "gumbel-softmax":
_, outputs_hot, extra_loss = gumbel_softmax(
inputs,
z_size=z_size,
mode=mode,
softmax_k=softmax_k,
temperature_warmup_steps=temperature_warmup_steps,
summary=summary,
name=name)
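# The discrete code is the argmax of the (relaxed) one-hot sample; a dense
# layer maps the sample back to hidden_size for downstream use.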
outputs_discrete = tf.argmax(outputs_hot, axis=-1)
outputs_dense = tf.layers.dense(
outputs_hot, hidden_size, name="dae_dense")
neg_q_entropy = tf.constant(0.0)
elif bottleneck_kind == "semhash":
outputs_discrete = tf.layers.dense(inputs, z_size, name="vcc")
y_clean = common_layers.saturating_sigmoid(outputs_discrete)
if summary:
tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
noise = tf.truncated_normal(
common_layers.shape_list(outputs_discrete),
mean=0.0,
stddev=noise_dev)
y = common_layers.saturating_sigmoid(outputs_discrete + noise)
else:
y = y_clean
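# Straight-through binarization: d thresholds y at 0.5, and y_discrete
# forwards the hard bits while letting gradients flow through the soft y.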
d = tf.to_float(tf.less(0.5, y))
y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
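# Mix discrete and continuous codes per example: the probability pd of using
# the discretized code ramps up via inverse_exp_decay over 2 * startup_steps
# (scaled by discrete_mix) during training, and is 1.0 at eval time.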
pd = common_layers.inverse_exp_decay(startup_steps * 2)
pd *= discrete_mix
pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
c = tf.where(
tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
y_discrete, y)
outputs_dense_a = tf.layers.dense(c, filter_size, name="vch1a")
outputs_dense_b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
outputs_dense = outputs_dense_a + outputs_dense_b
outputs_dense = tf.layers.dense(outputs_dense, hidden_size,
name="vch_final_linear")
dx = tf.to_int32(tf.stop_gradient(d))
outputs_discrete = bit_to_int(dx, z_size)
extra_loss = tf.constant(0.0)
neg_q_entropy = tf.constant(0.0)
elif bottleneck_kind == "vae":
outputs_discrete, extra_loss, _, _ = vae(inputs, z_size, name="vae")
outputs_dense = tf.layers.dense(
outputs_discrete, filter_size, name="vch1")
neg_q_entropy = tf.constant(0.0)
else:
raise ValueError("Unknown discretization method.")
return outputs_dense, outputs_discrete, extra_loss, embed_fn, neg_q_entropy
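# A minimal usage sketch (illustrative, not part of the source file). It
# assumes a TF1-style graph context with this module's imports (tf,
# common_layers, etc.) in scope; shapes and hyperparameters below are
# arbitrary assumptions. The "semhash" kind is used because it needs no
# externally created codebook variables.
#
#   latents = tf.random_normal([8, 10, 64])  # [batch, length, channels]
#   dense_out, codes, extra_loss, embed_fn, neg_q_entropy = discrete_bottleneck(
#       inputs=latents,
#       hidden_size=64,
#       z_size=14,
#       filter_size=128,
#       mode=tf.estimator.ModeKeys.TRAIN,
#       bottleneck_kind="semhash",
#       name="bottleneck")
#   # extra_loss is 0 for semhash; add it to the model loss for DVQ-style kinds.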