def grouped_attention_multihead()

in tensor2tensor/layers/common_attention.py [0:0]


def grouped_attention_multihead(query_antecedent,
                                memory_antecedent,
                                total_key_depth,
                                total_value_depth,
                                output_depth,
                                num_heads,
                                num_groups,
                                memory_target_density=2.0,
                                multiplicative_overhead=1.25,
                                additive_overhead=8.0,
                                mask_right=False,
                                make_image_summary=True,
                                name=None):
  """Multi-head dot-product attention with sparsity.

  For each attention head, the queries are partitioned into groups.
  For each group, only a subset of the key-value pairs are considered.

  The choices of groups are selected based on trained predictors of
  the total attention given the group inclusion.

  memory_target_density is the target average number of groups in which
  each key-value pair should participate.

  We use auxiliary losses to ensure that each group contains roughly
  the same number of queries and the same number of key-value pairs.
  The capacity of each group is its target size times
  multiplicative_overhead plus additive_overhead, capped at the sequence
  length.  If for a given sequence the actual number of queries/pairs
  sent to a group exceeds this capacity, then the last ones are dropped.
  We use this drop-last policy to avoid bleeding information backwards,
  which is necessary when using this function with autoregressive
  prediction.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels]
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    num_groups: an integer
    memory_target_density: a floating point scalar
    multiplicative_overhead: a floating point scalar
    additive_overhead: a floating point scalar
    mask_right: a boolean; if True, a query may not attend to memory
      positions after its own position.
    make_image_summary: a boolean; if True (and summaries are enabled),
      also emit the expensive attention image summaries.  Scalar and
      histogram summaries are emitted whenever summaries are enabled.
    name: an optional string

  Returns:
    A tuple (output, extra_loss):
      output: a Tensor with shape [batch, length_q, output_depth]
      extra_loss: a scalar Tensor holding the (down-scaled) auxiliary
        losses that train the group predictors and balance group sizes.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  query_shape = common_layers.shape_list(query_antecedent)
  batch, length_q = query_shape[0], query_shape[1]
  length_kv = common_layers.shape_list(memory_antecedent)[1]

  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  depth_qk = total_key_depth // num_heads
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  depth_v = total_value_depth // num_heads
  with tf.variable_scope(
      name, default_name="multihead_attention_sparse",
      values=[query_antecedent, memory_antecedent]):
    q = common_layers.dense(
        query_antecedent, total_key_depth, use_bias=False, name="q_transform")
    # Keys and values come from one projection; they are split apart after
    # dispatching (see tf.split below).
    kv = common_layers.dense(
        memory_antecedent,
        total_key_depth + total_value_depth,
        use_bias=False,
        name="kv_transform")
    q = split_heads(q, num_heads)
    kv = split_heads(kv, num_heads)
    # Make predictions about q_total and m_total.
    # These are used to determine group inclusion.
    # We will train these by auxiliary losses.  We use stop_gradient here
    # to keep these losses from back-propagating to the rest of the model.
    # We add biases that help balance the usage of the experts.
    q_pred = common_layers.dense(
        tf.stop_gradient(query_antecedent),
        num_heads * num_groups,
        use_bias=False,
        name="q_pred")
    q_pred = split_heads(q_pred, num_heads)
    q_bias = tf.get_variable("q_bias", [1, num_heads, 1, num_groups])
    q_pred_biased = q_pred + q_bias
    m_pred = common_layers.dense(
        tf.stop_gradient(memory_antecedent),
        num_heads * num_groups,
        use_bias=False,
        name="m_pred")
    m_pred = split_heads(m_pred, num_heads)
    m_bias = tf.get_variable("m_bias", [1, num_heads, 1, num_groups])
    m_pred_biased = m_pred + m_bias
    # Standard scaled dot-product attention: scale queries by 1/sqrt(d).
    q *= depth_qk**-0.5
    # q, kv, q_pred, m_pred are all [batch, heads, length_[q/m], ?]
    # now reshape them all to [batch * heads, length, ?]
    q = combine_first_two_dimensions(q)
    kv = combine_first_two_dimensions(kv)
    q_pred = combine_first_two_dimensions(q_pred)
    m_pred = combine_first_two_dimensions(m_pred)
    q_pred_biased = combine_first_two_dimensions(q_pred_biased)
    m_pred_biased = combine_first_two_dimensions(m_pred_biased)
    # Each query goes to exactly one group (its highest-scoring one).
    q_group = tf.argmax(q_pred_biased, axis=2)
    q_requests = tf.one_hot(q_group, num_groups, axis=-1)
    # Each memory position goes to every group whose biased score is > 0.
    m_requests = tf.to_float(tf.greater(m_pred_biased, 0.0))
    # include first memory position in all groups, to avoid division by zero.
    m_requests = tf.maximum(
        m_requests, tf.reshape(tf.one_hot([0], length_kv), [1, length_kv, 1]))
    q_group_size = tf.reduce_sum(q_requests, 1)
    m_group_size = tf.reduce_sum(m_requests, 1)
    q_group_target_size = tf.to_float(length_q) / tf.to_float(num_groups)
    m_group_target_size = (
        tf.to_float(length_kv) * memory_target_density /
        tf.to_float(num_groups))
    capacity_q = tf.minimum(
        length_q,
        tf.to_int32(q_group_target_size * multiplicative_overhead +
                    additive_overhead))
    capacity_m = tf.minimum(
        length_kv,
        tf.to_int32(m_group_target_size * multiplicative_overhead +
                    additive_overhead))
    q_dispatcher = expert_utils.TruncatingDispatcher(q_requests, capacity_q)
    m_dispatcher = expert_utils.TruncatingDispatcher(m_requests, capacity_m)
    q_gates = q_dispatcher.gates()
    m_gates = m_dispatcher.gates()
    dispatched_q = q_dispatcher.dispatch(q)
    dispatched_kv = m_dispatcher.dispatch(kv)
    # dispatched_q: [batch * num_heads, num_groups, capacity_q, depth_qk]
    # dispatched_kv:
    #   [batch * num_heads, num_groups, capacity_m, depth_qk + depth_v]
    k, v = tf.split(dispatched_kv, [depth_qk, depth_v], axis=3)
    logits = tf.matmul(dispatched_q, k, transpose_b=True)
    # -1e9 wherever nonpadding() is 0, i.e. (presumably) the empty slots of
    # each group's memory buffer, so they receive ~zero attention.
    bias = tf.expand_dims((m_dispatcher.nonpadding() - 1.0) * 1e9, 2)
    if mask_right:
      # length_coordinate() presumably gives each dispatched slot's position
      # in the original sequence; mask memory that lies after the query.
      q_coordinate = tf.to_float(
          tf.expand_dims(q_dispatcher.length_coordinate(), 3))
      m_coordinate = tf.to_float(
          tf.expand_dims(m_dispatcher.length_coordinate(), 2))
      bias += tf.to_float(tf.greater(m_coordinate, q_coordinate)) * -1e9
    logits += bias
    log_weights = tf.nn.log_softmax(logits)
    weights = tf.exp(log_weights)
    # For each query, this is the log of the sum of the unnormalized weights
    # (logit - log_softmax == logsumexp, identical for every key column).
    q_total = tf.stop_gradient(logits[:, :, :, :1] - log_weights[:, :, :, :1])
    # For each key, this is the sum of the normalized weights.
    m_total = tf.expand_dims(
        tf.reduce_sum(tf.stop_gradient(weights), axis=2), -1)
    o = tf.matmul(weights, v)
    o = q_dispatcher.combine(o)

    o = tf.reshape(o, [batch, num_heads, length_q, depth_v])
    o = combine_heads(o)
    o = common_layers.dense(
        o, output_depth, use_bias=False, name="output_transform")

    m_total = m_dispatcher.combine(m_total)
    q_total = q_dispatcher.combine(q_total)
    q_total = tf.squeeze(q_total, -1)
    m_total = tf.squeeze(m_total, -1)
    # Compute summed m predictions for all groups
    m_pred_used = tf.reduce_sum(tf.exp(m_pred) * m_gates, axis=2)
    q_pred_used = tf.reduce_sum(q_pred * q_gates, axis=2)
    epsilon = 1e-3
    m_pred_used = tf.log(m_pred_used + epsilon)
    m_total = tf.log(m_total + epsilon)
    m_loss = tf.nn.l2_loss(m_total - m_pred_used)
    # Weight by the summed gates so that queries that were not dispatched to
    # any group (sum of gates is 0) do not contribute to the loss.
    q_loss = tf.nn.l2_loss(
        (q_total - q_pred_used) * tf.reduce_sum(q_gates, axis=2))

    q_loss /= tf.to_float(batch * length_q)
    m_loss /= tf.to_float(batch * length_kv)

    # We would like the query groups to be equal sized.  The group
    # size is discrete, so we need some trick here.  We add a loss
    # proportional to the product of the group size and the
    # predictions for that group.  This encourages the predictions to
    # decrease for groups that are too big.
    q_group_deviation = (q_group_size / q_group_target_size) - 1.0
    q_balance_loss = tf.reduce_sum(
        tf.reduce_mean(q_pred_biased, axis=1) *
        q_group_deviation) / tf.to_float(batch)
    m_group_deviation = (m_group_size / m_group_target_size) - 1.0
    m_balance_loss = tf.reduce_sum(
        tf.reduce_mean(m_pred_biased, axis=1) *
        m_group_deviation) / tf.to_float(batch)

    # The losses in this function only propagate back to variables
    # defined in this function, and the losses outside of this
    # function only propagate back to variables outside of this
    # function.  Assuming some kind of adaptive learning algorithm,
    # it should not matter how much we scale the losses in this function.
    # Still we scale them down a lot so that they should not show up
    # much in the overall loss for the model.
    extra_loss_multiplier = 1e-3
    extra_loss = q_loss + m_loss + q_balance_loss + m_balance_loss
    extra_loss *= extra_loss_multiplier

    # Show a bunch of summaries.  The cheap scalar/histogram summaries are
    # emitted whenever summaries are enabled; make_image_summary only gates
    # the expensive image summaries below.
    if common_layers.should_generate_summaries():
      tf.summary.histogram("q_group_size", q_group_size)
      tf.summary.histogram("m_group_size", m_group_size)
      tf.summary.scalar("q_loss", q_loss)
      tf.summary.scalar("m_loss", m_loss)
      tf.summary.scalar("q_balance_loss", q_balance_loss)
      tf.summary.scalar("m_balance_loss", m_balance_loss)
      tf.summary.histogram("m_pred_used", m_pred_used)
      tf.summary.histogram("m_total", m_total)
      tf.summary.histogram("q_pred_used", q_pred_used)
      tf.summary.histogram("q_total", q_total)
      if make_image_summary:
        # image summaries are expensive.
        # So we restrict them to head_num<4, query_position<512, batch_index=0.
        # (The leading dimension is batch * num_heads with batch outermost,
        # so its first trunc_heads rows are the first heads of example 0.)
        trunc_heads = min(4, num_heads)
        trunc_length_q = tf.minimum(length_q, 512)
        # We recompute the attention for the first example, in an inefficient
        # way - masking.  This lets us show pretty pictures.
        # [trunc_heads, trunc_length_q, group]
        q_gates_trunc = q_gates[:trunc_heads, :trunc_length_q, :]
        # [trunc_heads, length_kv, group]
        m_gates_trunc = m_gates[:trunc_heads, :, :]
        # Nonzero where a (query, key) pair share at least one group.
        grouping_mask = tf.matmul(
            q_gates_trunc, m_gates_trunc, transpose_b=True)
        q_trunc = q[:trunc_heads, :trunc_length_q, :]
        k_trunc = kv[:trunc_heads, :, :depth_qk]
        logits_trunc = tf.matmul(q_trunc, k_trunc, transpose_b=True)
        if mask_right:
          band = common_layers.ones_matrix_band_part(trunc_length_q, length_kv,
                                                     -1, 0)
          trunc_bias = tf.expand_dims((1.0 - band) * -1e9, 0)
          logits_trunc += trunc_bias
        att_trunc = tf.nn.softmax(logits_trunc)
        # Fraction of the dense attention mass that the grouping retains.
        mask_coverage = tf.reduce_sum(grouping_mask * att_trunc) / (
            tf.to_float(trunc_length_q) * trunc_heads)
        tf.summary.scalar("coverage", mask_coverage)
        att_trunc_hdr = tf.pow(att_trunc, 0.2)  # for high-dynamic-range
        mask_channel = grouping_mask * tf.maximum(att_trunc_hdr, 0.3)
        image = tf.stack([att_trunc_hdr, mask_channel, mask_channel], axis=3)
        tf.summary.image("att", image, max_outputs=trunc_heads)
        # show one group for each head.
        att_per_group = tf.expand_dims(weights[:trunc_heads, 0, :, :], -1)
        tf.summary.image(
            "att_per_group",
            tf.pow(att_per_group, 0.2),
            max_outputs=trunc_heads)
    return o, extra_loss