in ma_policy/layers.py [0:0]
def self_attention(inp, mask, heads, n_embd, layer_norm=False, qk_w=1.0, v_w=0.01,
                   scope='', reuse=False):
    '''
        Self attention over entities.
        Notation:
            T  - Time
            NE - Number of entities
        Args:
            inp (tf) -- tensor w/ shape (bs, T, NE, features)
            mask (tf) -- binary tensor with shape (bs, T, NE). For each batch and timestep,
                            entry j indicates whether entity j exists and can be attended to
            heads (int) -- number of attention heads
            n_embd (int) -- total embedding dimension; queries, keys, and values each have
                            dimension n_embd / heads per head
            layer_norm (bool) -- normalize embedding prior to computing qkv
            qk_w, v_w (float) -- scale for gaussian init for keys/queries and values.
                Std will be sqrt(scale / n_embd)
            scope (string) -- tf scope
            reuse (bool) -- tf reuse
    '''
    with tf.variable_scope(scope, reuse=reuse):
        bs, T, NE, features = shape_list(inp)
        # Put mask in format correct for logit matrix
        entity_mask = None
        if mask is not None:
            with tf.variable_scope('expand_mask'):
                assert np.all(np.array(mask.get_shape().as_list()) == np.array(inp.get_shape().as_list()[:3])), \
                    f"Mask and input should have the same first 3 dimensions. {shape_list(mask)} -- {shape_list(inp)}"
                entity_mask = mask
                mask = tf.expand_dims(mask, -2)  # (bs, T, 1, NE)

        query, key, value = qkv_embed(inp, heads, n_embd, layer_norm=layer_norm, qk_w=qk_w, v_w=v_w, reuse=reuse)
        logits = tf.matmul(query, key, name="matmul_qk_parallel")  # (bs, T, heads, NE, NE)
        logits /= np.sqrt(n_embd / heads)
        softmax = stable_masked_softmax(logits, mask)
        att_sum = tf.matmul(softmax, value, name="matmul_softmax_value")  # (bs, T, heads, NE, n_embd / heads)
        with tf.variable_scope('flatten_heads'):
            out = tf.transpose(att_sum, (0, 1, 3, 2, 4))  # (bs, T, NE, heads, n_embd / heads)
            n_output_entities = shape_list(out)[2]
            out = tf.reshape(out, (bs, T, n_output_entities, n_embd))  # (bs, T, NE, n_embd)

    return out
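

# A minimal usage sketch (not part of the original file): it assumes TF1-style graph
# mode and that qkv_embed, stable_masked_softmax, and shape_list are defined elsewhere
# in ma_policy/layers.py, as referenced above. The shapes and scope name are illustrative.
if __name__ == '__main__':
    import numpy as np
    import tensorflow as tf

    bs, T, NE, features = 2, 5, 4, 32
    inp = tf.placeholder(tf.float32, shape=(bs, T, NE, features))  # (bs, T, NE, features)
    mask = tf.placeholder(tf.float32, shape=(bs, T, NE))           # 1 = entity j is visible

    out = self_attention(inp, mask, heads=4, n_embd=64,
                         layer_norm=True, scope='entity_self_attention')
    # out has static shape (bs, T, NE, n_embd) = (2, 5, 4, 64).
    # Note the design choice above: because the mask is expanded to (bs, T, 1, NE) and
    # broadcast over the query dimension, masking is per-key, i.e. a masked-out entity
    # is hidden from every attending entity at that timestep.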