def masked_relative_local_attention_1d()

in tensor2tensor/layers/common_attention.py [0:0]


def masked_relative_local_attention_1d(q,
                                       k,
                                       v,
                                       block_length=128,
                                       make_image_summary=False,
                                       dropout_rate=0.,
                                       heads_share_relative_embedding=False,
                                       add_relative_to_values=False,
                                       name=None):
  """Masked local 1d attention with relative positions.

  The sequence is divided into blocks of length block_length.
  Attention for a given query position can only see memory positions
  less than or equal to the query position, in the corresponding block
  and the previous block.

  Masking is always applied, so a target position can never see greater source
  positions.

  Args:
    q: a Tensor with shape [batch, heads, length, depth_k]
    k: a Tensor with shape [batch, heads, length, depth_k]
    v: a Tensor with shape [batch, heads, length, depth_v]
    block_length: an integer
    make_image_summary: a boolean, whether to make an attention image summary.
    dropout_rate: Dropout rate for attention dropout
    heads_share_relative_embedding: a boolean for sharing relative embeddings.
    add_relative_to_values: a boolean for whether to add relative component to
        values.
    name: an optional string

  Returns:
    a Tensor of shape [batch, heads, length, depth_v]

  Raises:
    ValueError: when the name for the variable scope is not passed.
  """
  if not name:
    raise ValueError("Name must be assigned since reuse for variable scope is "
                     "set to tf.AUTO_REUSE, in order to reuse relative "
                     "embeddings of keys and values.")

  # Reuse flag is set to auto_reuse to reuse relative embeddings of keys and
  # values across blocks (first and tail blocks).
  with tf.variable_scope(
      name, default_name="masked_relative_local_attention_1d",
      values=[q, k, v], reuse=tf.AUTO_REUSE):

    default_block_length = block_length
    batch = common_layers.shape_list(q)[0]
    heads = common_layers.shape_list(q)[1]
    length = common_layers.shape_list(q)[2]
    # If (length < 2 * block_length), then we use only one block.
    if isinstance(length, int) and isinstance(block_length, int):
      block_length = length if length < block_length * 2 else block_length
    else:
      block_length = tf.where(
          tf.less(length, block_length * 2), length, block_length)
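    # E.g. length = 200, block_length = 128: 200 < 256, so a single block of
    # length 200 is used.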
    depth_k = common_layers.shape_list(k)[3]
    depth_v = common_layers.shape_list(v)[3]
    original_length = length
    padding_size = tf.mod(-length, block_length)
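    # E.g. length = 300, block_length = 128: padding_size = 84, so the padded
    # length is 384 (three blocks).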
    length += padding_size
    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
    q = tf.pad(q, padding)
    k = tf.pad(k, padding)
    v = tf.pad(v, padding)

    num_blocks = length // block_length
    # compute attention for the first query block.
    first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
    first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
    first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
    # Relative embeddings will be used later as well.
    # TODO(avaswani,annahuang): check why 2*bl was breaking for music
    # Needs to be known at static shape inference time, hence cannot be
    # 2 * block_length.
    rel_embed_length = 4 * default_block_length
    # We only multiply with the embeddings we need; get_relative_embeddings_left
    # slices them out of the full set.
    first_rel_embeddings = get_relative_embeddings_left(
        rel_embed_length, block_length, depth_k, heads,
        heads_share_relative_embedding, "relative_embeddings")
    first_rel_logits = matmul_with_relative_keys(
        first_q, first_rel_embeddings, heads_share_relative_embedding)
    first_logits = tf.matmul(first_q, first_k, transpose_b=True)
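    # Shift the relative logits to absolute positions so they line up with the
    # content logits before the two are summed.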
    first_logits += (
        _relative_position_to_absolute_position_masked(first_rel_logits))
    # Add the causal mask so each position attends only to itself and earlier
    # positions.
    first_logits += (
        common_layers.cast_like(attention_bias_lower_triangle(block_length),
                                first_logits))
    first_att = tf.nn.softmax(first_logits,
                              name="first_attention_weights")
    # dropping out the attention links for each of the heads
    first_att = common_layers.dropout_with_broadcast_dims(
        first_att, 1.0 - dropout_rate,
        broadcast_dims=None)
    # only call image summary for the first block
    if common_layers.should_generate_summaries() and make_image_summary:
      attention_image_summary(first_att, None)
    first_output = tf.matmul(first_att, first_v)

    # compute attention for all subsequent query blocks.
    q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
    local_k = _make_local_block(k, depth_k, batch, heads, num_blocks,
                                block_length)
    local_v = _make_local_block(v, depth_v, batch, heads, num_blocks,
                                block_length)
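    # Each query block attends to itself and the previous block, so the local
    # keys/values span 2 * block_length positions per query block.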
    tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
    tail_q = tf.reshape(tail_q,
                        [batch, heads, num_blocks - 1, block_length, depth_k])
    local_length = common_layers.shape_list(local_k)[3]

    # Collapse num_blocks into the batch dimension so that we can reuse the
    # relative-attention helper functions.
    def _reshape_for_relative(x):
      x_shape = common_layers.shape_list(x)
      # [batch, num_blocks, heads, length, depth]
      x = tf.transpose(x, [0, 2, 1, 3, 4])
      x = tf.reshape(x, [batch*x_shape[2], heads, x_shape[3],
                         x_shape[4]])
      return x
    rel_tail_q = _reshape_for_relative(tail_q)
    rel_k = _reshape_for_relative(local_k)
    rel_v = _reshape_for_relative(local_v)
    rel_embeddings = get_relative_embeddings_left(
        rel_embed_length, 2 * block_length, depth_k, heads,
        heads_share_relative_embedding, "relative_embeddings")
    rel_logits = matmul_with_relative_keys(
        rel_tail_q, rel_embeddings, heads_share_relative_embedding)
    # Computing relative logits separately for the masked and unmasked parts
    # because the reshaping logic is different for both
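    # The masked part covers the current block (relative positions
    # -(block_length - 1)..0); the unmasked part covers the previous block
    # (relative positions -(2 * block_length - 1)..-1).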
    masked_rel_logits = tf.slice(rel_logits, [0, 0, 0, block_length],
                                 [-1, -1, -1, -1])
    masked_rel_logits = _relative_position_to_absolute_position_masked(
        masked_rel_logits)
    unmasked_rel_logits = tf.slice(rel_logits, [0, 0, 0, 0],
                                   [-1, -1, -1, 2*block_length-1])
    unmasked_rel_logits = _relative_position_to_absolute_position_unmasked(
        unmasked_rel_logits)
    all_rel_logits = tf.concat([unmasked_rel_logits, masked_rel_logits],
                               axis=3)
    all_logits = (
        tf.matmul(rel_tail_q, rel_k, transpose_b=True) + all_rel_logits)
    # make sure source_pos <= target_pos
    good_part = common_layers.ones_matrix_band_part(block_length,
                                                    local_length,
                                                    -1, block_length)
    mask = (1.0 - good_part) * -1e9
    mask = common_layers.cast_like(mask, all_logits)
    all_logits += tf.reshape(mask, [1, 1, block_length, local_length])
    weights = tf.nn.softmax(all_logits, name="attention_weights")
    # weights: [batch * (num_blocks - 1), heads, query_length (= block_length),
    # key_length (= 2 * block_length)]
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate,
        broadcast_dims=None)

    output = tf.matmul(weights, rel_v)
    if add_relative_to_values:
      # Adds the contribution of the weighted relative embeddings to the values.
      weights_for_unmasked, weights_for_masked = (
          tf.split(weights, 2, axis=3))
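      # The first block_length key positions come from the previous block
      # (unmasked); the last block_length come from the current block (masked).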
      rel_weights_unmasked = _absolute_position_to_relative_position_unmasked(
          weights_for_unmasked)
      rel_weights_masked = _absolute_position_to_relative_position_masked(
          weights_for_masked)

      value_rel_embeddings_unmasked = get_relative_embeddings_left(
          rel_embed_length, 2 * block_length, depth_v,
          heads, heads_share_relative_embedding,
          "value_relative_embeddings")
      # The unmasked relative positions run from -(2 * block_length - 1) to -1,
      # so drop the last embedding (relative position 0).
      if heads_share_relative_embedding:
        value_rel_embeddings_unmasked = value_rel_embeddings_unmasked[:-1, :]
      else:
        value_rel_embeddings_unmasked = value_rel_embeddings_unmasked[:, :-1, :]
      value_rel_embeddings_masked = get_relative_embeddings_left(
          rel_embed_length, block_length, depth_v,
          heads, heads_share_relative_embedding,
          "value_relative_embeddings")

      # [batch * (num_blocks - 1), heads, query_length, key_length]
      rel_weights = tf.concat(
          [rel_weights_unmasked, rel_weights_masked], axis=3)
      if heads_share_relative_embedding:
        value_rel_embeddings_concat_axis = 0
      else:
        value_rel_embeddings_concat_axis = 1
      value_rel_embeddings = tf.concat(
          [value_rel_embeddings_unmasked, value_rel_embeddings_masked],
          axis=value_rel_embeddings_concat_axis)
      output_rel = matmul_with_relative_values(
          rel_weights, value_rel_embeddings, heads_share_relative_embedding)
      output += output_rel

    # bring to [batch, heads, num_blocks-1, block_length, depth]
    output = tf.reshape(output,
                        [batch, num_blocks-1, heads, block_length, depth_v])
    output = tf.transpose(output, [0, 2, 1, 3, 4])

    output = tf.reshape(
        output, [batch, heads, (num_blocks - 1) * block_length, depth_v])
    output = tf.concat([first_output, output], axis=2)
    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
    output = tf.reshape(output, [batch, heads, original_length, depth_v])
    return output
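

Below is a minimal usage sketch, not part of the original source. It assumes
TF1-style graph execution (the function relies on tf.variable_scope with
tf.AUTO_REUSE) and that tensor2tensor's common_attention module is importable;
the shapes follow the docstring above.

# Minimal usage sketch (assumes tensor2tensor is installed, TF1 graph mode).
import tensorflow.compat.v1 as tf
from tensor2tensor.layers import common_attention

tf.disable_eager_execution()

batch, heads, length, depth = 2, 4, 300, 64
q = tf.random_normal([batch, heads, length, depth])
k = tf.random_normal([batch, heads, length, depth])
v = tf.random_normal([batch, heads, length, depth])

# name is required because the variable scope uses tf.AUTO_REUSE.
output = common_attention.masked_relative_local_attention_1d(
    q, k, v,
    block_length=128,
    heads_share_relative_embedding=True,
    name="masked_relative_local_attention_1d")

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  out = sess.run(output)
  print(out.shape)  # (2, 4, 300, 64)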