in tensor2tensor/layers/common_attention.py [0:0]
def grouped_attention_multihead(query_antecedent,
memory_antecedent,
total_key_depth,
total_value_depth,
output_depth,
num_heads,
num_groups,
memory_target_density=2.0,
multiplicative_overhead=1.25,
additive_overhead=8.0,
mask_right=False,
make_image_summary=True,
name=None):
"""Multi-head dot-product attention with sparsity.
For each attention head, the queries are partitioned into groups.
For each group, only a subset of the key-value pairs are considered.
Group membership is chosen based on trained predictors of the total
attention that would result from including the item in the group.
memory_target_density indicates the average number of groups in which
a key-value pair should participate.
We use auxiliary losses to ensure that each group contains roughly
the same number of queries and the same number of key-value pairs.
If, for a given sequence, the number of queries or key-value pairs sent
to a group exceeds its capacity (the target size times
multiplicative_overhead, plus additive_overhead), the excess items at the
end of the sequence are dropped. We use this drop-last policy to avoid
bleeding information backwards, which is necessary when using this
function for autoregressive prediction.
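Example (illustrative sketch; assumes `x` is a [batch, length, channels]
Tensor used for self-attention):
  y, extra_loss = grouped_attention_multihead(
      x, x,
      total_key_depth=512, total_value_depth=512, output_depth=512,
      num_heads=8, num_groups=8, mask_right=True)
The returned extra_loss is a scalar auxiliary loss that should be added
to the model's training loss.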
Args:
query_antecedent: a Tensor with shape [batch, length_q, channels]
memory_antecedent: a Tensor with shape [batch, length_m, channels]
total_key_depth: an integer
total_value_depth: an integer
output_depth: an integer
num_heads: an integer dividing total_key_depth and total_value_depth
num_groups: an integer
memory_target_density: a floating point scalar
multiplicative_overhead: a floating point scalar
additive_overhead: a floating point scalar
mask_right: a boolean
make_image_summary: a boolean
name: an optional string
Returns:
A tuple (output, extra_loss): a Tensor with shape
[batch, length_q, output_depth], and a scalar auxiliary loss to be
added to the training loss.
Raises:
ValueError: if the key depth or value depth is not divisible by the
number of attention heads.
"""
batch = common_layers.shape_list(query_antecedent)[0]
length_q = common_layers.shape_list(query_antecedent)[1]
length_kv = common_layers.shape_list(memory_antecedent)[1]
if total_key_depth % num_heads != 0:
raise ValueError("Key depth (%d) must be divisible by the number of "
"attention heads (%d)." % (total_key_depth, num_heads))
depth_qk = total_key_depth // num_heads
if total_value_depth % num_heads != 0:
raise ValueError("Value depth (%d) must be divisible by the number of "
"attention heads (%d)." % (total_value_depth, num_heads))
depth_v = total_value_depth // num_heads
with tf.variable_scope(
name, default_name="multihead_attention_sparse",
values=[query_antecedent, memory_antecedent]):
q = common_layers.dense(
query_antecedent, total_key_depth, use_bias=False, name="q_transform")
kv = common_layers.dense(
memory_antecedent,
total_key_depth + total_value_depth,
use_bias=False,
name="kv_transform")
q = split_heads(q, num_heads)
kv = split_heads(kv, num_heads)
# Make predictions about q_total and m_total.
# These are used to determine group inclusion.
# We will train these by auxiliary losses. We use stop_gradient here
# to keep these losses from back-propagating to the rest of the model.
# We add biases that help balance the usage of the experts.
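# q_pred and m_pred are [batch, length_{q,m}, num_heads * num_groups]
# before being split into heads below.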
q_pred = common_layers.dense(
tf.stop_gradient(query_antecedent),
num_heads * num_groups,
use_bias=False,
name="q_pred")
q_pred = split_heads(q_pred, num_heads)
q_bias = tf.get_variable("q_bias", [1, num_heads, 1, num_groups])
q_pred_biased = q_pred + q_bias
m_pred = common_layers.dense(
tf.stop_gradient(memory_antecedent),
num_heads * num_groups,
use_bias=False,
name="m_pred")
m_pred = split_heads(m_pred, num_heads)
m_bias = tf.get_variable("m_bias", [1, num_heads, 1, num_groups])
m_pred_biased = m_pred + m_bias
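# Scale the queries by 1/sqrt(depth_qk), as in standard dot-product attention.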
q *= depth_qk**-0.5
# q, kv, q_pred, m_pred are all [batch, heads, length_[q/m], ?]
# now reshape them all to [batch * heads, length, ?]
q = combine_first_two_dimensions(q)
kv = combine_first_two_dimensions(kv)
q_pred = combine_first_two_dimensions(q_pred)
m_pred = combine_first_two_dimensions(m_pred)
q_pred_biased = combine_first_two_dimensions(q_pred_biased)
m_pred_biased = combine_first_two_dimensions(m_pred_biased)
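# Each query is routed to the single group with the largest biased prediction;
# each memory position joins every group whose biased prediction is positive.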
q_group = tf.argmax(q_pred_biased, axis=2)
q_requests = tf.one_hot(q_group, num_groups, axis=-1)
m_requests = tf.to_float(tf.greater(m_pred_biased, 0.0))
# include first memory position in all groups, to avoid division by zero.
m_requests = tf.maximum(
m_requests, tf.reshape(tf.one_hot([0], length_kv), [1, length_kv, 1]))
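# Compute per-group sizes and the per-group capacity: the target size per
# group scaled by multiplicative_overhead, plus additive_overhead positions,
# capped at the sequence length.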
q_group_size = tf.reduce_sum(q_requests, 1)
m_group_size = tf.reduce_sum(m_requests, 1)
q_group_target_size = tf.to_float(length_q) / tf.to_float(num_groups)
m_group_target_size = (
tf.to_float(length_kv) * memory_target_density /
tf.to_float(num_groups))
capacity_q = tf.minimum(
length_q,
tf.to_int32(q_group_target_size * multiplicative_overhead +
additive_overhead))
capacity_m = tf.minimum(
length_kv,
tf.to_int32(m_group_target_size * multiplicative_overhead +
additive_overhead))
q_dispatcher = expert_utils.TruncatingDispatcher(q_requests, capacity_q)
m_dispatcher = expert_utils.TruncatingDispatcher(m_requests, capacity_m)
q_gates = q_dispatcher.gates()
m_gates = m_dispatcher.gates()
dispatched_q = q_dispatcher.dispatch(q)
dispatched_kv = m_dispatcher.dispatch(kv)
# dispatched_q: [batch * num_heads, num_groups, capacity_q, depth_qk]
# dispatched_kv:
# [batch * num_heads, num_groups, capacity_m, depth_qk + depth_v]
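# Split the dispatched key-value tensor back into keys and values, compute
# the attention logits, and mask out padded (undispatched) memory slots.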
k, v = tf.split(dispatched_kv, [depth_qk, depth_v], axis=3)
logits = tf.matmul(dispatched_q, k, transpose_b=True)
bias = tf.expand_dims((m_dispatcher.nonpadding() - 1.0) * 1e9, 2)
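# For autoregressive models, also mask out memory positions that lie to the
# right of the query position in the original sequence.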
if mask_right:
q_coordinate = tf.to_float(
tf.expand_dims(q_dispatcher.length_coordinate(), 3))
m_coordinate = tf.to_float(
tf.expand_dims(m_dispatcher.length_coordinate(), 2))
bias += tf.to_float(tf.greater(m_coordinate, q_coordinate)) * -1e9
logits += bias
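# Normalize in log space so that both the attention weights and the log
# partition function (needed for q_total below) are available.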
log_weights = tf.nn.log_softmax(logits)
weights = tf.exp(log_weights)
# For each query, this is the log of the sum of the unnormalized weights.
q_total = tf.stop_gradient(logits[:, :, :, :1] - log_weights[:, :, :, :1])
# For each key, this is the sum of the normalized weights.
m_total = tf.expand_dims(
tf.reduce_sum(tf.stop_gradient(weights), axis=2), -1)
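# Attend to the values, route the results back to their original query
# positions, merge the heads, and project to output_depth.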
o = tf.matmul(weights, v)
o = q_dispatcher.combine(o)
o = tf.reshape(o, [batch, num_heads, length_q, depth_v])
o = combine_heads(o)
o = common_layers.dense(
o, output_depth, use_bias=False, name="output_transform")
m_total = m_dispatcher.combine(m_total)
q_total = q_dispatcher.combine(q_total)
q_total = tf.squeeze(q_total, -1)
m_total = tf.squeeze(m_total, -1)
# Sum the m predictions over the groups to which each memory position was
# actually dispatched (selected by the dispatcher gates).
m_pred_used = tf.reduce_sum(tf.exp(m_pred) * m_dispatcher.gates(), axis=2)
q_pred_used = tf.reduce_sum(q_pred * q_dispatcher.gates(), axis=2)
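# Train the predictors towards the actual attention totals with L2 losses.
# The m quantities are compared in log space; epsilon keeps the logs finite
# when a position receives no attention.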
epsilon = 1e-3
m_pred_used = tf.log(m_pred_used + epsilon)
m_total = tf.log(m_total + epsilon)
m_loss = tf.nn.l2_loss(m_total - m_pred_used)
q_loss = tf.nn.l2_loss(
(q_total - q_pred_used) * tf.reduce_sum(q_gates, axis=2))
q_loss /= tf.to_float(batch * length_q)
m_loss /= tf.to_float(batch * length_kv)
# We would like the query groups to be equal sized. The group
# size is discrete, so we need some trick here. We add a loss
# proportional to the product of the group size and the
# predictions for that group. This encourages the predictions to
# decrease for groups that are too big.
q_group_deviation = (q_group_size / q_group_target_size) - 1.0
q_balance_loss = tf.reduce_sum(
tf.reduce_mean(q_pred_biased, axis=1) *
q_group_deviation) / tf.to_float(batch)
m_group_deviation = (m_group_size / m_group_target_size) - 1.0
m_balance_loss = tf.reduce_sum(
tf.reduce_mean(m_pred_biased, axis=1) *
m_group_deviation) / tf.to_float(batch)
# The losses in this function only propagate back to variables
# defined in this function, and the losses outside of this
# function only propagate back to variables outside of this
# function. Assuming some kind of adaptive learning algorithm,
# it should not matter how much we scale the losses in this function.
# Still we scale them down a lot so that they should not show up
# much in the overall loss for the model.
extra_loss_multiplier = 1e-3
extra_loss = q_loss + m_loss + q_balance_loss + m_balance_loss
extra_loss *= extra_loss_multiplier
# Show a bunch of summaries.
if common_layers.should_generate_summaries() and make_image_summary:
tf.summary.histogram("q_group_size", q_group_size)
tf.summary.histogram("m_group_size", m_group_size)
tf.summary.scalar("q_loss", q_loss)
tf.summary.scalar("m_loss", m_loss)
tf.summary.scalar("q_balance_loss", q_balance_loss)
tf.summary.scalar("m_balance_loss", m_balance_loss)
tf.summary.histogram("m_pred_used", m_pred_used)
tf.summary.histogram("m_total", m_total)
tf.summary.histogram("q_pred_used", q_pred_used)
tf.summary.histogram("q_total", q_total)
if make_image_summary:
# image summaries are expensive.
# So we restrict them to head_num<4, query_position<512, batch_index=0.
trunc_heads = min(4, num_heads)
trunc_length_q = tf.minimum(length_q, 512)
# We recompute the attention for the first example in an inefficient way,
# by masking. This lets us show pretty pictures.
# [trunc_heads, length_q, group]
q_gates_trunc = q_gates[:trunc_heads, :trunc_length_q, :]
# [trunc_heads, length_kv, group]
m_gates_trunc = m_gates[:trunc_heads, :, :]
grouping_mask = tf.matmul(
q_gates_trunc, m_gates_trunc, transpose_b=True)
q_trunc = q[:trunc_heads, :trunc_length_q, :]
k_trunc = kv[:trunc_heads, :, :depth_qk]
logits_trunc = tf.matmul(q_trunc, k_trunc, transpose_b=True)
if mask_right:
band = common_layers.ones_matrix_band_part(trunc_length_q, length_kv,
-1, 0)
trunc_bias = tf.expand_dims((1.0 - band) * -1e9, 0)
logits_trunc += trunc_bias
att_trunc = tf.nn.softmax(logits_trunc)
mask_coverage = tf.reduce_sum(grouping_mask * att_trunc) / (
tf.to_float(trunc_length_q) * trunc_heads)
tf.summary.scalar("coverage", mask_coverage)
att_trunc_hdr = tf.pow(att_trunc, 0.2) # for high-dynamic-range
mask_channel = grouping_mask * tf.maximum(att_trunc_hdr, 0.3)
image = tf.stack([att_trunc_hdr, mask_channel, mask_channel], axis=3)
tf.summary.image("att", image, max_outputs=trunc_heads)
# show one group for each head.
att_per_group = tf.expand_dims(weights[:trunc_heads, 0, :, :], -1)
tf.summary.image(
"att_per_group_%d",
tf.pow(att_per_group, 0.2),
max_outputs=trunc_heads)
return o, extra_loss