in tensorflow_addons/layers/multihead_attention.py [0:0]
def call(self, inputs, training=None, mask=None):
    # einsum nomenclature
    # ------------------------
    # N = query elements
    # M = key/value elements
    # H = heads
    # I = input features
    # O = output features
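    #
    # Shapes implied by the einsum equations below ("..." stands for any
    # number of leading batch dimensions):
    #   query: [..., N, I]       key/value: [..., M, I]
    #   per-input kernels: [H, I, head_size]
    #   projection kernel: [H, head_size, O]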
    query = inputs[0]
    key = inputs[1]
    value = inputs[2] if len(inputs) > 2 else key
    # verify shapes
    if key.shape[-2] != value.shape[-2]:
        raise ValueError(
            "the number of elements in 'key' must be equal to "
            "the number of elements in 'value'"
        )
    if mask is not None:
        if len(mask.shape) < 2:
            raise ValueError("'mask' must have at least 2 dimensions")
        if query.shape[-2] != mask.shape[-2]:
            raise ValueError(
                "mask's second to last dimension must be equal to "
                "the number of elements in 'query'"
            )
        if key.shape[-2] != mask.shape[-1]:
            raise ValueError(
                "mask's last dimension must be equal to the number of "
                "elements in 'key'"
            )
    # Linear transformations: project each input into per-head feature spaces
    query = tf.einsum("...NI , HIO -> ...NHO", query, self.query_kernel)
    key = tf.einsum("...MI , HIO -> ...MHO", key, self.key_kernel)
    value = tf.einsum("...MI , HIO -> ...MHO", value, self.value_kernel)

    # Scaled dot-product: scaling the query (or, equivalently, the key)
    # instead of their product saves some computation
    depth = tf.constant(self.head_size, dtype=query.dtype)
    query /= tf.sqrt(depth)

    # Calculate dot-product attention scores, shape [..., H, N, M]
    logits = tf.einsum("...NHO,...MHO->...HNM", query, key)
    # apply mask
    if mask is not None:
        mask = tf.cast(mask, tf.float32)

        # possibly expand on the head dimension so broadcasting works:
        # mask [..., N, M] -> [..., 1, N, M] against logits [..., H, N, M]
        if len(mask.shape) != len(logits.shape):
            mask = tf.expand_dims(mask, -3)

        # masked positions receive a large negative logit, so their
        # softmax weight is effectively zero
        logits += -10e9 * (1.0 - mask)
    attn_coef = tf.nn.softmax(logits)

    # attention dropout
    attn_coef_dropout = self.dropout(attn_coef, training=training)

    # attention * value, shape [..., N, H, head_size]
    multihead_output = tf.einsum("...HNM,...MHI->...NHI", attn_coef_dropout, value)

    # Run the outputs through another linear projection layer. Recombining
    # the heads is done automatically by contracting over both H and I.
    output = tf.einsum(
        "...NHI,HIO->...NO", multihead_output, self.projection_kernel
    )

    if self.projection_bias is not None:
        output += self.projection_bias

    if self.return_attn_coef:
        return output, attn_coef
    else:
        return output
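

# Usage sketch (illustration only, not part of this module): building the layer
# and calling it on separate query/key/value tensors. This assumes the method
# above belongs to tensorflow_addons.layers.MultiHeadAttention and that its
# constructor takes `head_size` and `num_heads`; the batch size, sequence
# lengths and feature sizes below are arbitrary.
import tensorflow as tf
import tensorflow_addons as tfa

mha = tfa.layers.MultiHeadAttention(head_size=16, num_heads=8)

query = tf.random.normal([2, 10, 64])  # [batch, N, input features]
key = tf.random.normal([2, 12, 64])    # [batch, M, input features]
value = tf.random.normal([2, 12, 64])  # [batch, M, input features]

output = mha([query, key, value])      # [batch, N, output features]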