def body_sharded()

in tensor2tensor/models/research/attention_lm_moe.py


  def body_sharded(self, sharded_features):
    # Dropout rates in hparams are expected to already be zero when not training.
    hparams = self._hparams
    dp = self._data_parallelism
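    # `dp` maps a function over the data-parallel shards: every dp(fn, ...)
    # call below runs fn once per shard and returns a list of per-shard
    # results.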
    if hparams.use_inputs:
      decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
      decoder_self_attention_bias = None
    else:
      targets = sharded_features["targets"]
      targets = dp(tf.squeeze, targets, 2)
      (decoder_input, decoder_self_attention_bias, pad_remover) = dp(
          attention_lm_moe_prepare_decoder, targets, hparams)

    # The two helpers below run the hparams-driven layer_preprocess_sequence /
    # layer_postprocess_sequence steps (e.g. layer norm, dropout, residual
    # add) on every shard.
    def preprocess(x):
      return dp(common_layers.layer_preprocess, x, hparams)

    def postprocess(x, y):
      return dp(common_layers.layer_postprocess, x, y, hparams)

    x = dp(tf.nn.dropout, decoder_input,
           1.0 - hparams.layer_prepostprocess_dropout)
    extra_loss = 0.0

    if not hparams.use_inputs:
      # Since preprocess and postprocess are called with a batch of size one
      # (all batches concatenated), make sure that batch norm is not used (it
      # should not be used here anyway).
      assert hparams.norm_type != "batch"

      tf.logging.info("Applying Padding Remover for the attention experts")

      dp_remove_pad = functools.partial(
          dp, remove_pad, pad_remover=pad_remover, mode=hparams.mode)
      dp_restore_pad = functools.partial(
          dp, restore_pad, ref_x=x, pad_remover=pad_remover, mode=hparams.mode)
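      # remove_pad flattens each shard's batch and drops the padded positions
      # before tokens are dispatched to the experts; restore_pad reinserts
      # them afterwards so the output keeps its original shape.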
    else:
      # Identity functions: no padding removal is applied.
      dp_remove_pad = lambda x: x
      dp_restore_pad = lambda x: x

    if hparams.attention_exp_factor != 0:
      tf.logging.info("Expand/compress tokens before sending them to experts")
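      # deconv_elems_1d expands each shard's sequence by attention_exp_factor
      # before it is dispatched to the experts, conv_elems_1d compresses it
      # back afterwards, and expand_batch_coordinates keeps the batch
      # coordinates in sync with the expanded sequence.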
      dp_expand_bc = lambda x: dp(  # pylint: disable=g-long-lambda
          expand_batch_coordinates,
          x,
          hparams.attention_exp_factor)
      dp_expand_x = lambda x: dp(  # pylint: disable=g-long-lambda
          common_attention.deconv_elems_1d,
          x,
          hparams.attention_exp_factor,
          hparams.attention_exp_inputdim)
      dp_compress_x = lambda x, l: dp(  # pylint: disable=g-long-lambda
          common_attention.conv_elems_1d,
          x,
          hparams.attention_exp_factor,
          l)
    else:
      dp_expand_bc = lambda x: x
      dp_expand_x = lambda x: x
      dp_compress_x = lambda x, l: x

    def print_shape(x, suffix, debug=False):
      # To help debugging, print the input/output shapes at inference and eval
      # time. Inference on long sequences can take a long time, so this helps
      # to track the progress of the generation.
      if not debug and hparams.mode == ModeKeys.TRAIN:
        return x
      return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix))

    with tf.name_scope("batch_coordinate_preprocess"):
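      # batch_coordinate tags every position with the index of the example it
      # belongs to (so expert attention does not mix tokens from different
      # examples) and batch_order tags it with its position within the
      # sequence; both go through the same remove-pad / expand pipeline as x.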
      batch_coordinate = dp(get_batch_coordinate, x)
      batch_coordinate = dp_remove_pad(batch_coordinate)
      batch_coordinate = dp_expand_bc(batch_coordinate)
      batch_order = dp(get_batch_coordinate, x, axis=-1)
      batch_order = dp_remove_pad(batch_order)
      batch_order = dp_expand_bc(batch_order)

    x = dp(print_shape, x, "in")

    assert hparams.batch_size >= hparams.max_length

    num_hidden_layers = (
        len(hparams.attention_layers) or hparams.num_hidden_layers)
    for layer in range(num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):

        # Use the layer type defined in attention_layers
        if hparams.attention_layers:
          attention_type = LAYER_SYMBOLS[hparams.attention_layers[layer]]
        else:
          attention_type = hparams.attention_type

        with tf.variable_scope(
            "attention_{}".format(attention_type)):
          if attention_type in [
              AttentionType.MULTIHEAD, AttentionType.MULTIHEAD_FULL]:
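            # Standard multihead self-attention: use masked local attention
            # when attention_local is set; MULTIHEAD_FULL always uses global
            # dot-product attention instead.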
            attention_dot_type = (
                "local_mask_right" if hparams.attention_local else
                "dot_product")
            if attention_type == AttentionType.MULTIHEAD_FULL:
              attention_dot_type = "dot_product"
            y = dp(
                common_attention.multihead_attention,
                preprocess(x),
                None,
                decoder_self_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                attention_type=attention_dot_type,
                block_length=hparams.attention_block_length,
                name="decoder_self_attention")
          elif attention_type == AttentionType.SPARSE_MULTIHEAD:
            x_in = preprocess(x)
            x_in = dp_remove_pad(x_in)
            y, loss_experts = dp(
                common_attention.multihead_attention_sparse_dot_prod,
                x_in,
                None,
                None,  # Bias is computed inside
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,

                # Additional parameters
                bi=[common_attention.BatchInfo(
                    coordinates=batch_coordinate[i],
                    order=batch_order[i],  # No future mask
                ) for i in range(dp.n)],
                use_map_fn=hparams.lsh_use_map_fn,
                experts_params=dict(
                    nb_hyperplanes=hparams.lsh_num_hyperplanes,
                ),
            )
            y = dp_restore_pad(y)

            # TODO(avaswani, epot, noam): Do we need to divide by num shards?
            extra_loss += tf.add_n(loss_experts) / dp.n
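            # (loss_experts holds one auxiliary expert loss per shard, so the
            # losses are summed and averaged over the dp.n shards.)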
          elif attention_type == AttentionType.SPARSE_MULTIHEAD_TRUNCATED:
            x_in = preprocess(x)
            y, loss_experts = dp(
                common_attention.multihead_attention_sparse_truncated,
                x_in,
                None,
                None,  # Bias is computed inside
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,

                # Additional parameters
                bi=[common_attention.BatchInfo(
                    coordinates=batch_coordinate[i],
                    order=batch_order[i],  # No future mask
                ) for i in range(dp.n)],
                mask_right=True,
                experts_params=dict(
                    nb_hyperplanes=hparams.lsh_num_hyperplanes,
                ),
            )

            # TODO(avaswani, epot, noam): Do we need to divide by num shards?
            extra_loss += tf.add_n(loss_experts) / dp.n
          elif attention_type == AttentionType.MEMORY_EFFICIENT:
            assert hparams.layer_preprocess_sequence == "n"
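            # The memory-efficient attention op fuses the layer norm, which is
            # why the preprocess sequence must be exactly "n" and preprocess(x)
            # is not called here.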
            y = dp(
                common_attention.multihead_self_attention_memory_efficient,
                x,
                decoder_self_attention_bias,
                hparams.num_heads,
                name="decoder_self_attention")
          elif attention_type == AttentionType.MULTIHEAD_REDUCED:
            y = dp(
                common_attention.multihead_self_attention_reduced,
                preprocess(x),
                factor=hparams.attention_red_factor,
                reduction_type=hparams.attention_reduction_type,
                nonlinearity=hparams.attention_nonlinearity,
                multihead_params=dict(
                    total_key_depth=
                    hparams.attention_key_channels or hparams.hidden_size,
                    total_value_depth=
                    hparams.attention_value_channels or hparams.hidden_size,
                    num_heads=hparams.num_heads,
                    dropout_rate=hparams.attention_dropout,
                ))
          elif attention_type == AttentionType.LOCAL_EXPERTS:
            x_in = preprocess(x)
            x_in = dp_remove_pad(x_in)
            x_in = dp_expand_x(x_in)
            y, loss = dp(
                common_attention.local_expert_attention,
                x_in,
                k=hparams.attention_moe_k,
                loss_coef=hparams.attention_load_balance,
                attention_num_experts=hparams.attention_num_experts,
                train=hparams.mode == ModeKeys.TRAIN,
                batch_coordinate=batch_coordinate,
                mask_right=not hparams.use_inputs,
                split_batch=bool(hparams.attention_split_batch),
                attention_num_head=hparams.attention_num_head,
                attention_kq_size=hparams.attention_kq_size,
                attention_v_size=hparams.attention_v_size)
            y = dp_compress_x(y, x[0].get_shape().as_list()[-1])
            y = dp_restore_pad(y)
            # TODO(avaswani, epot, noam): Do we need to divide by num shards?
            extra_loss += tf.add_n(loss) / dp.n
          else:
            raise ValueError("Only {} supported for now.".format(
                AttentionType.get_choices()))
          x = postprocess(x, y)
        with tf.variable_scope("ffn"):
          if hparams.memory_efficient_ffn:
            assert hparams.layer_preprocess_sequence == "n"
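            # Same constraint as the memory-efficient attention above: the op
            # fuses the layer norm, so the preprocess sequence must be exactly
            # "n".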
            y = dp(
                common_layers.conv_hidden_relu_memory_efficient,
                x,
                hparams.filter_size)
          else:
            additional_conv_params = {}
            if hparams.use_sepconv:
              additional_conv_params = dict(
                  padding="LEFT",
                  # Parameters copied from the transformer model
                  kernel_size=(3, 1),
                  second_kernel_size=(31, 1),
              )
            y = dp(
                common_layers.conv_hidden_relu,
                preprocess(x),
                hparams.filter_size,
                hparams.hidden_size,
                dropout=hparams.relu_dropout,
                **additional_conv_params
            )
          x = postprocess(x, y)
    x = preprocess(x)

    decoder_output = dp(tf.expand_dims, x, 2)
    return decoder_output, extra_loss
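
For readers unfamiliar with the dp(...) calls used throughout body_sharded, the
minimal sketch below illustrates the map-over-shards pattern they rely on. It
uses a toy stand-in rather than tensor2tensor's real expert_utils.Parallelism
(which additionally transposes tuple results so that "y, loss = dp(...)"
unpacking works); the names ToyParallelism and fake_attention are hypothetical
and only meant to show why per-shard results come back as lists and why the
auxiliary expert losses are summed and divided by dp.n.

# Minimal illustrative sketch; NOT the real expert_utils.Parallelism API.
class ToyParallelism(object):
  """Applies a function once per shard and collects the per-shard results."""

  def __init__(self, num_shards):
    self.n = num_shards

  def __call__(self, fn, *args, **kwargs):
    # List arguments are assumed to hold one value per shard; anything else
    # is broadcast to every shard.
    def shard_arg(arg, i):
      return arg[i] if isinstance(arg, list) else arg
    return [fn(*[shard_arg(a, i) for a in args], **kwargs)
            for i in range(self.n)]


def fake_attention(x):
  # Stand-in for an attention layer that also returns an auxiliary loss.
  return x * 2, 0.5


dp = ToyParallelism(num_shards=2)
results = dp(fake_attention, [1.0, 3.0])  # one (output, loss) pair per shard
outputs = [y for y, _ in results]         # [2.0, 6.0]
losses = [loss for _, loss in results]    # [0.5, 0.5]

# Mirrors "extra_loss += tf.add_n(loss_experts) / dp.n" in body_sharded():
extra_loss = sum(losses) / dp.n           # 0.5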