def self_attention()

in ma_policy/layers.py


import numpy as np
import tensorflow as tf

# shape_list, qkv_embed, and stable_masked_softmax are helper functions
# defined elsewhere in ma_policy/layers.py.


def self_attention(inp, mask, heads, n_embd, layer_norm=False, qk_w=1.0, v_w=0.01,
                   scope='', reuse=False):
    '''
        Self attention over entities.
        Notation:
            T  - Time
            NE - Number of entities
        Args:
            inp (tf) -- tensor w/ shape (bs, T, NE, features)
            mask (tf) -- binary tensor with shape (bs, T, NE). For each batch and
                            timestep, entry j says whether entity j can be attended
                            to; it is broadcast across all querying entities i
            heads (int) -- number of attention heads
            n_embd (int) -- total embedding dimension; each head uses queries,
                            keys, and values of dimension n_embd / heads
            layer_norm (bool) -- normalize embedding prior to computing qkv
            qk_w, v_w (float) -- scales for the Gaussian init of the query/key and
                value projections. Init std will be sqrt(scale / n_embd)
            scope (string) -- tf scope
            reuse (bool) -- tf reuse
    '''
    with tf.variable_scope(scope, reuse=reuse):
        bs, T, NE, features = shape_list(inp)
        # Put mask in format correct for logit matrix
        entity_mask = None
        if mask is not None:
            with tf.variable_scope('expand_mask'):
                assert np.all(np.array(mask.get_shape().as_list()) == np.array(inp.get_shape().as_list()[:3])),\
                    f"Mask and input should have the same first 3 dimensions. {shape_list(mask)} -- {shape_list(inp)}"
                entity_mask = mask  # (bs, T, NE); not used further in this function
                mask = tf.expand_dims(mask, -2)  # (bs, T, 1, NE)

        query, key, value = qkv_embed(inp, heads, n_embd, layer_norm=layer_norm, qk_w=qk_w, v_w=v_w, reuse=reuse)
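        # Assumption about qkv_embed (defined elsewhere in this module): query is
        # (bs, T, heads, NE, n_embd/heads) and key is returned already transposed to
        # (bs, T, heads, n_embd/heads, NE), so the matmul below yields NE x NE logits.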
        logits = tf.matmul(query, key, name="matmul_qk_parallel")  # (bs, T, heads, NE, NE)
        logits /= np.sqrt(n_embd / heads)
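        # Scaled dot-product attention: dividing by sqrt(d_head), with
        # d_head = n_embd / heads, keeps the logit variance roughly constant
        # as the per-head dimension grows (Vaswani et al., 2017).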
        softmax = stable_masked_softmax(logits, mask)
        att_sum = tf.matmul(softmax, value, name="matmul_softmax_value")  # (bs, T, heads, NE, n_embd/heads)
        with tf.variable_scope('flatten_heads'):
            out = tf.transpose(att_sum, (0, 1, 3, 2, 4))  # (bs, T, n_output_entities, heads, n_embd/heads)
            n_output_entities = shape_list(out)[2]
            out = tf.reshape(out, (bs, T, n_output_entities, n_embd))  # (bs, T, n_output_entities, n_embd)

        return out
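
For reference, stable_masked_softmax is not shown in this excerpt. Below is a
minimal sketch of what a numerically stable masked softmax can look like under
the shape conventions above; the name and body are assumptions for illustration,
not the repo's implementation:

def stable_masked_softmax_sketch(logits, mask):
    # logits: (bs, T, heads, NE, NE); mask: (bs, T, 1, NE) or None
    if mask is not None:
        mask = tf.expand_dims(mask, 2)  # (bs, T, 1, 1, NE), broadcasts over heads and queries
        logits -= (1.0 - mask) * 1e10   # push masked-out entities toward -inf
    # Subtract the per-row max so exp() cannot overflow
    logits -= tf.reduce_max(logits, axis=-1, keepdims=True)
    unnormalized = tf.exp(logits)
    if mask is not None:
        unnormalized *= mask            # zero out masked entries exactly
    return unnormalized / (tf.reduce_sum(unnormalized, axis=-1, keepdims=True) + 1e-10)

And a usage sketch (TF1-style graph mode; the shapes and hyperparameters are
illustrative assumptions):

bs, T, NE, features = 8, 5, 12, 64
inp = tf.placeholder(tf.float32, (bs, T, NE, features))  # entity features
mask = tf.placeholder(tf.float32, (bs, T, NE))           # 1.0 where entity j is present
out = self_attention(inp, mask, heads=4, n_embd=128, scope='sa')
# out has shape (bs, T, NE, 128)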