mesh_tensorflow/transformer/funnel_transformer.py [387:436]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
        encoder_sequence_id)
    batch_dims = inputs.shape[:-1]
    length_dim = inputs.shape[-1]
    if max_decode_length is None:
      decode_length_dim = length_dim
    else:
      decode_length_dim = mtf.Dimension("length", max_decode_length)
    if beam_size == 1:
      ids_shape = mtf.Shape(batch_dims + [decode_length_dim])
      partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
      return self.decoder.sample_autoregressive(
          partial_sequences,
          temperature=temperature,
          sampling_keep_top_k=sampling_keep_top_k,
          variable_dtype=variable_dtype,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
          shared_params=shared_params,
          has_partial_sequences=False,
          encoder_layer_outputs=encoder_layer_outputs)
    else:
      if temperature != 0:
        raise ValueError(
            "don't know how to beam search with nonzero temperature")
      if sampling_keep_top_k != -1:
        raise ValueError(
            "don't know how to beam search with top-k value other than -1.")
      # beam search
      beam_dim = mtf.Dimension("beam", beam_size)
      ids_shape = mtf.Shape(batch_dims + [beam_dim, decode_length_dim])
      partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * decode_length_multiplier
          + decode_length_constant, tf.int32)
      return self.decoder.beam_search(
          partial_sequences,
          decode_length,
          variable_dtype=variable_dtype,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          encoder_inputs=inputs,
          alpha=alpha,
          shared_params=shared_params,
          encoder_layer_outputs=encoder_layer_outputs)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



mesh_tensorflow/transformer/transformer.py [1650:1699]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
        encoder_sequence_id)
    batch_dims = inputs.shape[:-1]
    length_dim = inputs.shape[-1]
    if max_decode_length is None:
      decode_length_dim = length_dim
    else:
      decode_length_dim = mtf.Dimension("length", max_decode_length)
    if beam_size == 1:
      ids_shape = mtf.Shape(batch_dims + [decode_length_dim])
      partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
      return self.decoder.sample_autoregressive(
          partial_sequences,
          temperature=temperature,
          sampling_keep_top_k=sampling_keep_top_k,
          variable_dtype=variable_dtype,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
          shared_params=shared_params,
          has_partial_sequences=False,
          encoder_layer_outputs=encoder_layer_outputs)
    else:
      if temperature != 0:
        raise ValueError(
            "don't know how to beam search with nonzero temperature")
      if sampling_keep_top_k != -1:
        raise ValueError(
            "don't know how to beam search with top-k value other than -1.")
      # beam search
      beam_dim = mtf.Dimension("beam", beam_size)
      ids_shape = mtf.Shape(batch_dims + [beam_dim, decode_length_dim])
      partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * decode_length_multiplier
          + decode_length_constant, tf.int32)
      return self.decoder.beam_search(
          partial_sequences,
          decode_length,
          variable_dtype=variable_dtype,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          encoder_inputs=inputs,
          alpha=alpha,
          shared_params=shared_params,
          encoder_layer_outputs=encoder_layer_outputs)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



