in mesh_tensorflow/transformer/attention.py [0:0]
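# Note: this excerpt assumes the module-level imports of attention.py,
# typically `import mesh_tensorflow as mtf` and a TF1-style
# `import tensorflow.compat.v1 as tf` (the exact import lines are not shown
# in this excerpt).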
def synthetic_attention(q,
                        k,
                        v,
                        memory_length_dim,
                        key_dim,
                        value_dim,
                        bias=None,
                        dropout_rate=0.0,
                        dropout_broadcast_dims=None,
                        extra_logit=None,
                        synthesize=True,
                        synthesize_mode="random_plus_alpha",
                        factorized_dim=16,
                        max_length=512,
                        context=None):
"""Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).
key_dim is a Dimension representing the channels in the queries and keys
value_dim is a Dimension representing the channels in values
memory_length_dim is a Dimension representing the different key/value pairs.
Dimensions of q: other_query_dims + {key_dim}
Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
other_memory_dims is a subset of other_query_dims
Typically, other_query_dims={batch, heads, length}
Typically, other_memory_dims={batch, heads}
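
  For example, with the typical dims above:
    q: [batch, heads, length, key_dim]
    k: [batch, heads, memory_length, key_dim]
    v: [batch, heads, memory_length, value_dim]
    output: [batch, heads, length, value_dim]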

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: an optional Tensor added to the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor
    synthesize: a boolean, whether to use synthetic attention.
    synthesize_mode: a string, which synthesizer variant to use; one of
      "random", "factorized", "dense_minus", "random_plus",
      "random_plus_alpha", "dense_plus" or "dense_plus_alpha".
    factorized_dim: an integer, the rank of the factorized random synthesizer.
    max_length: an integer, the maximum input sequence length.
    context: a Context object; provides the mesh, variable dtype, the
      mode/position for incremental decoding, and the training flag for
      dropout.

  Returns:
    A Tensor with shape q.shape - key_dim + value_dim
  """
  if synthesize:
    num_heads = v.shape.get_dim_by_name("heads")
    tf.logging.info("Using synthesizer")
    if synthesize_mode == "random":
      tf.logging.info("Using Random Synthesizers")
      r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", max_length)])
      r = mtf.get_variable(context.mesh, "R", r_shape,
                           initializer=None,
                           dtype=context.variable_dtype)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, "length")
      logits = r
      r_shape = logits.shape
    elif synthesize_mode == "factorized":
      tf.logging.info("Using Factorized Random Synthesizers")
      r1_shape = mtf.Shape([mtf.Dimension("tmp", factorized_dim),
                            mtf.Dimension("heads", num_heads.size),
                            mtf.Dimension("length", max_length)])
      r2_shape = mtf.Shape([mtf.Dimension("tmp", factorized_dim),
                            mtf.Dimension("heads", num_heads.size),
                            mtf.Dimension("memory_length", max_length)])
      r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", max_length)])
      r1 = mtf.get_variable(context.mesh, "R1", r1_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      r2 = mtf.get_variable(context.mesh, "R2", r2_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      # Contract over the shared "tmp" dim to recover the full logits matrix.
      r = mtf.einsum([r1, r2], r_shape)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, "length")
      logits = r
    elif synthesize_mode == "dense_minus":
      # Dense Synthesizer Model
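      # The dense synthesizer predicts the logits from the query alone:
      # relu(q) is projected with a learned dense layer to memory_length
      # logits per position; keys are not used in this branch.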
      tmp_dim = mtf.Dimension("memory_length", max_length)
      logits = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                use_bias=False,
                                name="pi",
                                reduced_dims=[key_dim],
                                variable_dtype=None)
      logits = mtf.slice(logits, 0, memory_length_dim.size,
                         memory_length_dim.name)
      if context.mode == "incremental":
        pass
      else:
        length_dim = q.shape.get_dim_by_name("length")
        logits = mtf.slice(logits, 0, length_dim.size, "length")
    elif synthesize_mode in ("random_plus_alpha", "random_plus"):
      # Mixture Random Synthesizer with learnable Alpha
      tf.logging.info("Using Random Plus Alpha")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      num_heads = logits.shape.get_dim_by_name("heads")
      r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", max_length)])
      r = mtf.get_variable(context.mesh, "R", r_shape,
                           initializer=None,
                           dtype=context.variable_dtype)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, length_dim.name)
      if "alpha" in synthesize_mode:
        alpha = mtf.get_variable(context.mesh,
                                 "alpha",
                                 mtf.Shape([mtf.Dimension("alpha", 1)]),
                                 initializer=tf.zeros_initializer(),
                                 dtype=context.variable_dtype)
        alpha = mtf.sigmoid(alpha)
        logits = ((1 - alpha) * logits) + (alpha * r)
      else:
        logits = logits + r
    elif synthesize_mode in ("dense_plus_alpha", "dense_plus"):
      # Mixture Dense Synthesizer with learnable alpha
      tf.logging.info("Using Dense Plus Alpha Scaling")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      tmp_dim = mtf.Dimension("memory_length", max_length)
      r = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                           use_bias=False,
                           name="pi",
                           reduced_dims=[key_dim],
                           variable_dtype=None)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        pass
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, "length")
      if "alpha" in synthesize_mode:
        alpha = mtf.get_variable(context.mesh,
                                 "alpha",
                                 mtf.Shape([mtf.Dimension("alpha", 1)]),
                                 initializer=tf.zeros_initializer(),
                                 dtype=context.variable_dtype)
        alpha = mtf.sigmoid(alpha)
        logits = ((1 - alpha) * logits) + (alpha * r)
      else:
        logits = logits + r
  else:
    # Assumed fallback: standard dot-product logits when synthesize is False.
    logits = mtf.einsum([q, k], reduced_dims=[key_dim])
  if bias is not None:
    logits += bias
  weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
  weights = mtf.dropout(
      weights, context.train, 1.0 - dropout_rate,
      noise_shape=weights.shape - dropout_broadcast_dims)
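  # The pure synthesizer modes produce logits whose shape is not simply
  # q.shape - key_dim, so the output shape for the final einsum is rebuilt
  # explicitly; the mixture ("plus") modes and plain dot-product attention
  # keep the usual q.shape - key_dim + value_dim.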
if synthesize and "plus" not in synthesize_mode:
if synthesize_mode == "dense_minus":
outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
else:
outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim])
else:
outputs_shape = q.shape - [key_dim] + value_dim
outputs = mtf.einsum([weights, v], outputs_shape)
return outputs
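
# Hedged usage sketch (illustrative only, not part of attention.py): callers
# pass already-projected q/k/v Tensors together with the Transformer context,
# mirroring the signature above, e.g.
#
#   o = synthetic_attention(q, k, v,
#                           memory_length_dim=memory_length_dim,
#                           key_dim=key_dim, value_dim=value_dim,
#                           bias=bias, dropout_rate=0.1,
#                           synthesize_mode="random_plus_alpha",
#                           context=context)
#
# where q/k/v come from the usual query/key/value projections and context
# supplies the mesh, mode/position and variable dtype.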