in tensor2tensor/layers/common_attention.py [0:0]
def masked_relative_local_attention_1d(q,
k,
v,
block_length=128,
make_image_summary=False,
dropout_rate=0.,
heads_share_relative_embedding=False,
add_relative_to_values=False,
name=None):
"""Masked local 1d attention with relative positions.
  The sequence is divided into blocks of length block_length.
  Attention for a given query position can only see memory positions
  less than or equal to the query position, in the corresponding block
  and the previous block. Attention is always masked so that a target
  position cannot see source positions greater than itself.
Args:
q: a Tensor with shape [batch, heads, length, depth_k]
k: a Tensor with shape [batch, heads, length, depth_k]
v: a Tensor with shape [batch, heads, length, depth_v]
block_length: an integer
make_image_summary: a boolean, whether to make an attention image summary.
dropout_rate: Dropout rate for attention dropout
heads_share_relative_embedding: a boolean for sharing relative embeddings.
add_relative_to_values: a boolean for whether to add relative component to
values.
name: an optional string
Returns:
a Tensor of shape [batch, heads, length, depth_v]
Raises:
    ValueError: when the name for the variable scope is not passed.
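
  Example:
    A minimal sketch (the shapes and the "rel_local_attn" scope name are
    illustrative):

    q = tf.random_normal([2, 4, 256, 64])
    k = tf.random_normal([2, 4, 256, 64])
    v = tf.random_normal([2, 4, 256, 64])
    out = masked_relative_local_attention_1d(
        q, k, v, block_length=128, name="rel_local_attn")
    # out: a Tensor with shape [2, 4, 256, 64]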
"""
if not name:
raise ValueError("Name must be assigned since reuse for variable scope is "
"set to tf.AUTO_REUSE, in order to reuse relative "
"embeddings of keys and values.")
# Reuse flag is set to auto_reuse to reuse relative embeddings of keys and
# values across blocks (first and tail blocks).
with tf.variable_scope(
name, default_name="masked_relative_local_attention_1d",
values=[q, k, v], reuse=tf.AUTO_REUSE):
default_block_length = block_length
batch = common_layers.shape_list(q)[0]
heads = common_layers.shape_list(q)[1]
length = common_layers.shape_list(q)[2]
# If (length < 2 * block_length), then we use only one block.
if isinstance(length, int) and isinstance(block_length, int):
block_length = length if length < block_length * 2 else block_length
else:
block_length = tf.where(
tf.less(length, block_length * 2), length, block_length)
depth_k = common_layers.shape_list(k)[3]
depth_v = common_layers.shape_list(v)[3]
original_length = length
padding_size = tf.mod(-length, block_length)
length += padding_size
padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
q = tf.pad(q, padding)
k = tf.pad(k, padding)
v = tf.pad(v, padding)
num_blocks = length // block_length
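    # length is now a multiple of block_length. The first block attends only to
    # itself; every later block attends to itself and the block before it, a
    # local window of 2 * block_length keys.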
# compute attention for the first query block.
first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
# Relative embeddings will be used later as well.
# TODO(avaswani,annahuang): check why 2*bl was breaking for music
# Needs to be known at static shape inference time, hence cannot be
# 2 * block_length.
rel_embed_length = 4 * default_block_length
    # Only the embeddings that are actually needed are used; they are sliced
    # out of the full set inside get_relative_embeddings_left.
first_rel_embeddings = get_relative_embeddings_left(
rel_embed_length, block_length, depth_k, heads,
heads_share_relative_embedding, "relative_embeddings")
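    # first_rel_embeddings: [block_length, depth_k] if
    # heads_share_relative_embedding, else [heads, block_length, depth_k].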
first_rel_logits = matmul_with_relative_keys(
first_q, first_rel_embeddings, heads_share_relative_embedding)
first_logits = tf.matmul(first_q, first_k, transpose_b=True)
first_logits += (
_relative_position_to_absolute_position_masked(first_rel_logits))
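    # _relative_position_to_absolute_position_masked rearranges logits indexed
    # by relative position (-(block_length - 1)..0) into logits indexed by
    # absolute key position, so they can be added to the query-key logits.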
    # Add a causal mask so each query attends only to positions at or before it.
first_logits += (
common_layers.cast_like(attention_bias_lower_triangle(block_length),
first_logits))
first_att = tf.nn.softmax(first_logits,
name="first_attention_weights")
# dropping out the attention links for each of the heads
first_att = common_layers.dropout_with_broadcast_dims(
first_att, 1.0 - dropout_rate,
broadcast_dims=None)
# only call image summary for the first block
if common_layers.should_generate_summaries() and make_image_summary:
attention_image_summary(first_att, None)
first_output = tf.matmul(first_att, first_v)
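    # first_output: [batch, heads, block_length, depth_v].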
# compute attention for all subsequent query blocks.
q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
local_k = _make_local_block(k, depth_k, batch, heads, num_blocks,
block_length)
local_v = _make_local_block(v, depth_v, batch, heads, num_blocks,
block_length)
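    # local_k, local_v: [batch, heads, num_blocks - 1, 2 * block_length, depth];
    # each block of keys/values is concatenated with the block preceding it.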
tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
tail_q = tf.reshape(tail_q,
[batch, heads, num_blocks - 1, block_length, depth_k])
local_length = common_layers.shape_list(local_k)[3]
    # Collapse the block dimension into the batch dimension so that we can
    # reuse the relative-attention helpers below.
def _reshape_for_relative(x):
x_shape = common_layers.shape_list(x)
# [batch, num_blocks, heads, length, depth]
x = tf.transpose(x, [0, 2, 1, 3, 4])
x = tf.reshape(x, [batch*x_shape[2], heads, x_shape[3],
x_shape[4]])
return x
rel_tail_q = _reshape_for_relative(tail_q)
rel_k = _reshape_for_relative(local_k)
rel_v = _reshape_for_relative(local_v)
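    # rel_tail_q: [batch * (num_blocks - 1), heads, block_length, depth_k];
    # rel_k, rel_v: [batch * (num_blocks - 1), heads, 2 * block_length, depth].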
rel_embeddings = get_relative_embeddings_left(
rel_embed_length, 2 * block_length, depth_k, heads,
heads_share_relative_embedding, "relative_embeddings")
rel_logits = matmul_with_relative_keys(
rel_tail_q, rel_embeddings, heads_share_relative_embedding)
# Computing relative logits separately for the masked and unmasked parts
# because the reshaping logic is different for both
masked_rel_logits = tf.slice(rel_logits, [0, 0, 0, block_length],
[-1, -1, -1, -1])
masked_rel_logits = _relative_position_to_absolute_position_masked(
masked_rel_logits)
unmasked_rel_logits = tf.slice(rel_logits, [0, 0, 0, 0],
[-1, -1, -1, 2*block_length-1])
unmasked_rel_logits = _relative_position_to_absolute_position_unmasked(
unmasked_rel_logits)
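    # masked_rel_logits holds relative positions -(block_length - 1)..0 for
    # keys in the query's own block; unmasked_rel_logits holds positions
    # -(2 * block_length - 1)..-1 for keys in the previous block, which every
    # query may attend to in full.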
all_rel_logits = tf.concat([unmasked_rel_logits, masked_rel_logits],
axis=3)
all_logits = (
tf.matmul(rel_tail_q, rel_k, transpose_b=True) + all_rel_logits)
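    # all_logits combines content logits (query . key) with relative-position
    # logits over the 2 * block_length local window.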
# make sure source_pos <= target_pos
good_part = common_layers.ones_matrix_band_part(block_length,
local_length,
-1, block_length)
mask = (1.0 - good_part) * -1e9
mask = common_layers.cast_like(mask, all_logits)
all_logits += tf.reshape(mask, [1, 1, block_length, local_length])
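    # good_part is 1 where the key position is at or before the query position
    # within the 2-block window (queries occupy the second half), so the added
    # -1e9 bias removes attention to future positions.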
weights = tf.nn.softmax(all_logits, name="attention_weights")
    # weights: [batch * (num_blocks - 1), heads, query_length (= block_length),
    #           key_length (= 2 * block_length)]
weights = common_layers.dropout_with_broadcast_dims(
weights, 1.0 - dropout_rate,
broadcast_dims=None)
output = tf.matmul(weights, rel_v)
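    # output: [batch * (num_blocks - 1), heads, block_length, depth_v].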
if add_relative_to_values:
# Adds the contribution of the weighted relative embeddings to the values.
weights_for_unmasked, weights_for_masked = (
tf.split(weights, 2, axis=3))
rel_weights_unmasked = _absolute_position_to_relative_position_unmasked(
weights_for_unmasked)
rel_weights_masked = _absolute_position_to_relative_position_masked(
weights_for_masked)
value_rel_embeddings_unmasked = get_relative_embeddings_left(
rel_embed_length, 2 * block_length, depth_v,
heads, heads_share_relative_embedding,
"value_relative_embeddings")
      # The unmasked relative positions end at -1 rather than 0, so drop the
      # last (relative position 0) embedding.
if heads_share_relative_embedding:
value_rel_embeddings_unmasked = value_rel_embeddings_unmasked[:-1, :]
else:
value_rel_embeddings_unmasked = value_rel_embeddings_unmasked[:, :-1, :]
value_rel_embeddings_masked = get_relative_embeddings_left(
rel_embed_length, block_length, depth_v,
heads, heads_share_relative_embedding,
"value_relative_embeddings")
      # rel_weights: [batch * (num_blocks - 1), heads, block_length,
      #               3 * block_length - 1] (unmasked and masked relative
      #               positions concatenated).
rel_weights = tf.concat(
[rel_weights_unmasked, rel_weights_masked], axis=3)
if heads_share_relative_embedding:
value_rel_embeddings_concat_axis = 0
else:
value_rel_embeddings_concat_axis = 1
value_rel_embeddings = tf.concat(
[value_rel_embeddings_unmasked, value_rel_embeddings_masked],
axis=value_rel_embeddings_concat_axis)
output_rel = matmul_with_relative_values(
rel_weights, value_rel_embeddings, heads_share_relative_embedding)
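      # output_rel: [batch * (num_blocks - 1), heads, block_length, depth_v],
      # the contribution of the relative value embeddings for each query.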
output += output_rel
# bring to [batch, heads, num_blocks-1, block_length, depth]
output = tf.reshape(output,
[batch, num_blocks-1, heads, block_length, depth_v])
output = tf.transpose(output, [0, 2, 1, 3, 4])
output = tf.reshape(
output, [batch, heads, (num_blocks - 1) * block_length, depth_v])
output = tf.concat([first_output, output], axis=2)
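    # Drop the padding added earlier so the output length matches the input.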
output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
output = tf.reshape(output, [batch, heads, original_length, depth_v])
return output