# In tensor2tensor/models/research/transformer_vae.py:
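# This function relies on module-level imports in transformer_vae.py
# (tensorflow as tf, common_layers, common_attention, contrib, and
# cia, presumably common_image_attention) and on helpers defined elsewhere in
# the file (encode, compress, decompress_step, residual_conv, attend,
# decode_transformer, ae_latent_softmax, ae_latent_sample, _DO_SUMMARIES).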
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
"""AE Transformer, main step used for training."""
# Summaries break with the do_refine cond, turn them off in that case.
global _DO_SUMMARIES
if hparams.do_refine:
_DO_SUMMARIES = False
  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
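  # targets is now [batch_size, length, 1, hidden_size].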
  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None
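  # `ed` is the attention bias returned by encode(); the *_ex copies are kept
  # for the latent-prediction ("extra") decoder and for latent sampling below.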
  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0),
            "neg_q_entropy": tf.constant(0.0)}
  if hparams.do_ae:
    # flatten here
    original_targets = targets
    original_targets_shape = tf.shape(original_targets)
    if hparams.task == "image":
      cia.maybe_reshape_4d_to_3d(targets)
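      # NOTE: the result of maybe_reshape_4d_to_3d is discarded here, so the
      # call above is effectively a no-op.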
if hparams.task == "translate":
if inputs is not None:
max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
else:
max_targets_len_from_inputs = targets
else:
assert hparams.task == "image"
max_targets_len_from_inputs = targets
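    # max_targets_len_from_inputs only sets the padding reference used below;
    # concatenating inputs with themselves presumably allows targets up to
    # roughly twice the encoder input length.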
    if hparams.word_shuffle:
      tf.logging.info("Using word shuffle with rate = {}".format(
          hparams.word_shuffle))
      targets_idx = tf.range(start=0,
                             limit=common_layers.shape_list(targets)[1],
                             delta=1)
      targets_idx = tf.to_float(targets_idx)
      noise = tf.random_uniform(shape=common_layers.shape_list(targets_idx),
                                minval=0,
                                maxval=1 + hparams.word_shuffle)
      targets_idx += noise
      permutation = contrib.framework().argsort(targets_idx)
      targets_permuted = tf.gather(targets, indices=permutation, axis=1)
      targets = targets_permuted
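      # Each position index gets uniform noise in [0, 1 + word_shuffle) before
      # the argsort, so tokens are only displaced locally (by roughly at most
      # word_shuffle positions).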
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
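    # Padding to a multiple of 2**num_compress_steps keeps the lengths integral
    # through the compress/decompress steps below.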
    # Add positional information
    targets_shape = common_layers.shape_list(targets)
    targets = tf.reshape(targets, [targets_shape[0], targets_shape[1],
                                   targets_shape[3]])
    targets = common_attention.add_positional_embedding(
        targets, hparams.max_length, name="targets_position")
    targets = tf.reshape(targets, shape=targets_shape)
    if hparams.word_dropout:
      mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                               minval=0.0, maxval=1.0)
      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                               tf.zeros_like(targets))
    else:
      targets_noisy = targets
    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
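    # targets_c is the compressed target sequence; compress() presumably
    # shrinks the length axis 2**num_compress_steps-fold before the bottleneck.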
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
          hparams.bottleneck(inputs=targets_c,
                             filter_size=hparams.compress_filter_size,
                             mode=hparams.mode,
                             name="vc"))
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex,
            embed(latents_discrete), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
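        # The "extra" decoder acts as a learned prior over the discrete codes;
        # stop_gradient keeps this prediction loss from reshaping the codes
        # themselves.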
        # Scale by latent dimension for summary so we can compare across
        # batches.
        if _DO_SUMMARIES:
          tf.summary.scalar("latent_pred_loss_mean",
                            tf.reduce_mean(latent_pred_loss))
        if hparams.sum_over_latents:
          latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
        losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean(
            tf.squared_difference(inputs_c, targets_c)) * 20
        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _, _ = hparams.bottleneck(
                inputs=inputs_c,
                filter_size=hparams.compress_filter_size,
                mode=hparams.mode,
                name="vc")
          return bn
        inputs_c = bn_inputs()
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
                                 latents_dense, inputs_c)
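        # For dense/VAE bottlenecks the prior is continuous: "dec_c" decodes
        # the inputs toward targets_c under a scaled L2 loss, and its
        # re-bottlenecked output replaces latents_dense for up to half of the
        # examples as the ptc schedule decays from 1.0 to 0.5.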
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _, _ = hparams.bottleneck(
            inputs=inputs_c,
            filter_size=hparams.compress_filter_size,
            mode=hparams.mode,
            name="vc")
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed, _ = hparams.bottleneck(
            inputs=targets_c,
            filter_size=hparams.compress_filter_size,
            name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(
              latents_dense, inputs_ex, ed_ex, embed, 16, hparams)
        latents_dense = embed(cache)
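      # At predict time the discrete latents cannot come from the targets; they
      # are sampled from the learned prior with ae_latent_sample, and the cache
      # argument lets an outer decoding loop reuse previously sampled codes.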
    # Postprocess.
    d = latents_dense
    d_shape = common_layers.shape_list(d)
    d = tf.reshape(d, [d_shape[0], d_shape[1], d_shape[3]])
    d = common_attention.add_positional_embedding(
        d, hparams.max_length, name="latents_position")
    d = tf.reshape(d, shape=d_shape)
    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if inputs is not None and hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
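    # After num_compress_steps decompression steps, d is back at the padded
    # target length and is used below in place of masked-out target positions.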
    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)
    else:
      targets = d
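    # masking ramps from 0 to 1 over mask_startup_steps, so the decoder input
    # moves gradually from ground-truth targets to the decompressed latents d;
    # at predict time (with the default predict_mask=1.0) only d is fed.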
  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]),
                      target_space, hparams, "refine_enc")
        return tf.expand_dims(r, axis=2)
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
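      # The refinement encoder is applied only to examples whose mask kept
      # (almost) no ground-truth positions, i.e. that were decoded purely from
      # the latents.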
    # We'll start training the extra model of latents after mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  # res was generated from padded targets, which means it has some extra
  # elements. These can cause shape problems when computing the loss with
  # respect to the original (unpadded) targets, so we remove the extra
  # elements here.
  res = res[:, :original_targets_shape[1], :, :]
  data_dim = common_layers.shape_list(res)[1]
  latent_dim = common_layers.shape_list(targets_c)[1]
  return res, losses, cache, data_dim, latent_dim