def ae_transformer_internal()

in tensor2tensor/models/research/transformer_vae.py [0:0]


def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """AE Transformer, main step used for training and inference.

  Optionally encodes `inputs`. When `hparams.do_ae` is set, it compresses the
  (optionally shuffled / dropped-out) targets into a latent sequence via
  `hparams.bottleneck`. It then decompresses those latents and mixes them into
  the targets according to the masking schedule. Finally it runs the decoder.

  Args:
    inputs: input tensor, or None for models without inputs.
    targets: target tensor; reshaped here to
      [batch, length, 1, hparams.hidden_size].
    target_space: target space id, passed to `encode`.
    hparams: model hyperparameters.
    cache: optional latent codes. It is only read in PREDICT mode with a
      discrete bottleneck; if None there, codes are sampled with
      `ae_latent_sample`.
    predict_mask: masking value that replaces the scheduled one when
      `hparams.use_predict_mask` is set or in PREDICT mode.

  Returns:
    A tuple (res, losses, cache, data_dim, latent_dim):
      res: decoder output truncated to the original target length.
      losses: dict with the "extra", "latent_pred" and "neg_q_entropy" losses.
      cache: the latent codes used, possibly newly sampled.
      data_dim: size of `res` along axis 1.
      latent_dim: size of the compressed targets along axis 1, or None when
        `hparams.do_ae` is False (no latents are computed in that case).
  """
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
  # Shape before any padding. It is used to truncate the decoder output back
  # to the original length, so it must exist whether or not do_ae is set.
  original_targets_shape = tf.shape(targets)
  # Compressed targets; only computed when hparams.do_ae is set.
  targets_c = None

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0),
            "neg_q_entropy": tf.constant(0.0)}
  if hparams.do_ae:
    # targets are already flattened to [batch, length, 1, depth] by the
    # reshape above (for images too), so no extra flattening is needed here.
    if hparams.task == "translate":
      if inputs is not None:
        # Pad targets to at least twice the encoded input length.
        max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
      else:
        max_targets_len_from_inputs = targets
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    if hparams.word_shuffle:
      tf.logging.info("Using word shuffle with rate = {}".format(
          hparams.word_shuffle))
      # Local shuffle: jitter each position index by uniform noise in
      # [0, 1 + word_shuffle) and reorder positions by the jittered index.
      targets_idx = tf.range(start=0,
                             limit=common_layers.shape_list(targets)[1],
                             delta=1)
      targets_idx = tf.to_float(targets_idx)
      noise = tf.random_uniform(shape=common_layers.shape_list(targets_idx),
                                minval=0,
                                maxval=1 + hparams.word_shuffle)
      targets_idx += noise
      permutation = contrib.framework().argsort(targets_idx)
      targets = tf.gather(targets, indices=permutation, axis=1)
    # Length must be divisible by the total compression factor.
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    # Add positional information
    targets_shape = common_layers.shape_list(targets)
    targets = tf.reshape(targets, [targets_shape[0], targets_shape[1],
                                   targets_shape[3]])
    targets = common_attention.add_positional_embedding(
        targets, hparams.max_length, name="targets_position")
    targets = tf.reshape(targets, shape=targets_shape)
    if hparams.word_dropout:
      # NOTE(review): the mask has the full shape of targets, so this zeroes
      # individual embedding elements rather than whole positions — confirm
      # this is intended.
      mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                               minval=0.0, maxval=1.0)
      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                               tf.zeros_like(targets))
    else:
      targets_noisy = targets

    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
          hparams.bottleneck(inputs=targets_c,
                             filter_size=hparams.compress_filter_size,
                             mode=hparams.mode,
                             name="vc"))
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      # Per example, use the bottleneck output with probability pc (decayed
      # in during training, always 1.0 otherwise), else the raw targets_c.
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex,
            embed(latents_discrete), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)

        # Scale by latent dimension for summary so we can compare across
        # batches.
        if _DO_SUMMARIES:
          tf.summary.scalar("latent_pred_loss_mean",
                            tf.reduce_mean(latent_pred_loss))
        if hparams.sum_over_latents:
          latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])

        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
        losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
      else:
        # Continuous bottleneck: regress the compressed targets from inputs.
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean(
            tf.squared_difference(inputs_c, targets_c)) * 20
        # Pass the predicted latents through the same (shared) bottleneck.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          inputs_c, _, _, _, _ = hparams.bottleneck(
              inputs=inputs_c,
              filter_size=hparams.compress_filter_size,
              mode=hparams.mode,
              name="vc")
        # Probability of keeping latents_dense instead of the predicted
        # latents. It goes from 1.0 towards 0.5 during training (presumably
        # over 200000 steps, per inverse_lin_decay) and is 1.0 otherwise.
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
                                 latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _, _ = hparams.bottleneck(
            inputs=inputs_c,
            filter_size=hparams.compress_filter_size,
            mode=hparams.mode,
            name="vc")
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed, _ = hparams.bottleneck(
            inputs=targets_c,
            filter_size=hparams.compress_filter_size,
            name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(
              latents_dense, inputs_ex, ed_ex, embed, 16, hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    d_shape = common_layers.shape_list(d)
    d = tf.reshape(d, [d_shape[0], d_shape[1], d_shape[3]])
    d = common_attention.add_positional_embedding(
        d, hparams.max_length, name="latents_position")
    d = tf.reshape(d, shape=d_shape)

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if inputs is not None and hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if (hparams.use_predict_mask or
          hparams.mode == tf.estimator.ModeKeys.PREDICT):
        masking = predict_mask
      # mask is 1.0 (keep the target) where a uniform draw exceeds `masking`,
      # otherwise 0.0 (use the decompressed latent d at that position).
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        # NOTE(review): this requires that padding above added no positions;
        # otherwise the element counts differ and the reshape fails.
        targets = tf.reshape(targets, original_targets_shape)
    else:
      targets = d

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        r, _ = encode(tf.squeeze(res, axis=[2]),
                      target_space, hparams, "refine_enc")
        return tf.expand_dims(r, axis=2)
      # Examples with (almost) no kept target positions, i.e. fully replaced
      # by latents, take the refined output instead of the decoder output.
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)

  # res was generated from padded targets, which means it has some extra
  # elements. These can cause shape problems when computing loss with respect to
  # the original (unpadded) targets. So we remove their extra elements here.
  res = res[:, :original_targets_shape[1], :, :]

  data_dim = common_layers.shape_list(res)[1]
  latent_dim = (None if targets_c is None
                else common_layers.shape_list(targets_c)[1])
  return res, losses, cache, data_dim, latent_dim