def body_sharded()

in tensor2tensor/models/research/aligned.py [0:0]


  def body_sharded(self, sharded_features):
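    """Applies the stack of sublayers named in hparams.layers to each shard.

    Args:
      sharded_features: a dict of lists of Tensors, one entry per data shard;
        the "inputs" shards have shape [batch, length, 1, hidden_size].

    Returns:
      A (sharded output, extra_loss) pair: output shards of shape
      [batch, length, 1, hidden_size], and extra_loss accumulating the
      auxiliary losses returned by some attention sublayers.
    """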
    # Dropout rates come from hparams and are zeroed out when not training.
    hparams = self._hparams
    dp = self._data_parallelism
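    # The "inputs" shards are [batch, length, 1, hidden]; squeeze out the dummy spatial axis.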
    x = dp(tf.squeeze, sharded_features["inputs"], 2)

    def preprocess(x):
      return dp(common_layers.layer_preprocess, x, hparams)

    def postprocess(x, y):
      return dp(common_layers.layer_postprocess, x, y, hparams)

    x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
    extra_loss = 0.0
    ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")]
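    # Self-attention bias: lower-triangular (causal) when masking the future,
    # otherwise a broadcastable zero tensor.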
    if hparams.mask_right:

      def _bias(x):
        return common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(x)[1])

      bias = dp(_bias, x)
    else:
      bias = tf.zeros([1, 1, 1, 1])

    # Per-position batch indices, used by the expert and LSH attention variants.
    batch_coordinate = dp(get_batch_coordinate, x)

    layers = hparams.layers.strip(",").split(",")
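    # Apply each sublayer named in hparams.layers in order, wrapping it in
    # layer pre/postprocessing where appropriate.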
    for layer_num, layer_type in enumerate(layers):
      with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
        if _should_preprocess(layer_type):
          x = preprocess(x)
        if layer_type == "timing":
          y = dp(common_attention.add_timing_signal_nd, x)
        elif layer_type == "pos_emb":
          y = dp(
              common_attention.add_positional_embedding_nd,
              x,
              hparams.max_length,
              name="pos_emb")
        elif layer_type == "att":
          y = dp(
              common_attention.multihead_attention,
              x,
              None,
              bias,  # causal (lower-triangular) bias if mask_right, else zeros
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout)
        elif layer_type == "att_grouped":
          multiplicative_overhead = (
              hparams.multiplicative_overhead if hparams.mode == ModeKeys.TRAIN
              else hparams.multiplicative_overhead_eval)
          y, loss = dp(
              common_attention.grouped_attention_multihead,
              x,
              x,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              num_groups=hparams.attention_num_groups,
              memory_target_density=hparams.memory_target_density,
              multiplicative_overhead=multiplicative_overhead,
              make_image_summary=hparams.attention_image_summary,
              mask_right=hparams.mask_right,
          )
          extra_loss += tf.add_n(loss) / dp.n
        elif layer_type == "att_memory_efficient":
          assert hparams.layer_preprocess_sequence == "n"
          y = dp(common_attention.multihead_self_attention_memory_efficient, x,
                 bias, hparams.num_heads)
        elif layer_type == "att_local":
          y = dp(
              common_attention.multihead_attention,
              x,
              None,
              None,  # bias
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=("local_mask_right"
                              if hparams.mask_right else "local_unmasked"),
              block_length=hparams.local_attention_window,
              block_width=hparams.local_attention_window)
        elif layer_type == "att_pseudolocal":
          # This is an inefficient implementation of local attention, for the
          # purpose of testing model quality.
          def _pseudolocal_bias(x):
            return common_attention.attention_bias_local(
                common_layers.shape_list(x)[1], hparams.local_attention_window,
                0 if hparams.mask_right else hparams.local_attention_window)

          pseudolocal_bias = dp(_pseudolocal_bias, x)
          y = dp(common_attention.multihead_attention, x, None,
                 pseudolocal_bias, hparams.attention_key_channels or
                 hparams.hidden_size, hparams.attention_value_channels or
                 hparams.hidden_size, hparams.hidden_size, hparams.num_heads,
                 hparams.attention_dropout)
        elif layer_type == "att_local_expert":
          y, loss = dp(
              common_attention.local_expert_attention,
              x,
              k=hparams.attention_moe_k,
              loss_coef=hparams.attention_load_balance,
              attention_num_experts=hparams.attention_num_experts,
              train=hparams.mode == ModeKeys.TRAIN,
              batch_coordinate=batch_coordinate,
              mask_right=hparams.mask_right,
              split_batch=bool(hparams.attention_split_batch),
              attention_kq_size=hparams.attention_kq_size,
              attention_v_size=hparams.attention_v_size)
          # TODO(avaswani, epot, noam): Do we need to divide by num shards?
          extra_loss += tf.add_n(loss) / dp.n
        elif layer_type == "att_lsh":
          if hparams.lsh_truncated:
            attention_fn = common_attention.multihead_attention_sparse_truncated
          else:
            attention_fn = common_attention.multihead_attention_sparse_dot_prod
          y, loss = dp(
              attention_fn,
              x,
              None,
              None,  # Bias is computed inside
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,

              # Additional parameters
              bi=[
                  common_attention.BatchInfo(
                      coordinates=batch_coordinate[i],
                      order=None,  # No future mask
                  ) for i in range(dp.n)
              ],
              use_map_fn=False,
              experts_params=dict(nb_hyperplanes=4,))
          extra_loss += tf.add_n(loss) / dp.n
        elif layer_type == "ffn":
          y = dp(
              expert_utils.ffn_expert_fn(hparams.hidden_size, ffn_hidden_sizes,
                                         hparams.hidden_size),
              dp(expert_utils.flatten_all_but_last, x))
          y = dp(common_layers.reshape_like, y, x)
        elif layer_type == "conv":
          y = dp(
              common_layers.conv1d,
              x,
              hparams.hidden_size,
              hparams.kernel_height,
              activation=tf.nn.relu,
              padding="SAME",
          )
        else:
          assert False, "unknown sublayer %s" % layer_type
        if _should_postprocess(layer_type):
          x = postprocess(x, y)
        else:
          x = y
    x = preprocess(x)
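    # Restore the dummy spatial axis below so the output is [batch, length, 1, hidden].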

    decoder_output = dp(tf.expand_dims, x, 2)
    return decoder_output, extra_loss
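
The sublayer schedule is driven entirely by two comma-separated hparams
strings. The following minimal, self-contained sketch (hypothetical values,
not taken from aligned.py) mirrors only the parsing logic of the loop above:

  # Hypothetical example: how hparams.layers and hparams.ffn_hidden_sizes
  # would be interpreted by body_sharded.
  layers_spec = "timing,att,ffn,att,ffn"
  ffn_hidden_sizes = [int(s) for s in "2048,2048".split(",")]  # -> [2048, 2048]
  for layer_num, layer_type in enumerate(layers_spec.strip(",").split(",")):
    print(layer_num, layer_type)  # 0 timing, 1 att, 2 ffn, 3 att, 4 ffn

Each token selects one of the branches above ("timing", "att", "ffn", ...),
so a stack is described purely by string configuration.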