def residual_sa_block()

in ma_policy/layers.py


import numpy as np
import tensorflow as tf

# Relies on self_attention(), defined earlier in ma_policy/layers.py.
def residual_sa_block(inp, mask, heads, n_embd,
                      layer_norm=False, post_sa_layer_norm=False,
                      n_mlp=1, qk_w=0.125, v_w=0.125, post_w=0.125,
                      mlp_w1=0.125, mlp_w2=0.125,
                      scope="residual_sa_block", reuse=False):
    '''
        Residual self attention block for entities.
        Notation:
            BS - Batch size
            T  - Time
            NE - Number of entities
            f  - Feature dimension
        Args:
            inp (tf): (BS, T, NE, f). f must equal n_embd so that the residual
                connection after the post-attention dense layer is well defined.
            mask (tf): (BS, T, NE)
            heads (int) -- number of attention heads
            n_embd (int) -- queries, keys, and values in each head will have dimension n_embd / heads
            layer_norm (bool) -- normalize the embedding prior to computing qkv
            post_sa_layer_norm (bool) -- layer-normalize the output of the first residual
                connection (after the post-attention dense layer)
            n_mlp (int) -- number of mlp layers. If there is more than one mlp layer, a residual
                connection is added from after the first mlp to after the last mlp.
            qk_w, v_w, post_w, mlp_w1, mlp_w2 (float) -- scales for the Gaussian weight initializers of
                the queries/keys, the values, the post-attention dense layer (mlp1), the second mlp,
                and the third mlp, respectively. Each std is sqrt(scale / n_embd).
            scope (string) -- tf scope
            reuse (bool) -- tf reuse
    '''
    with tf.variable_scope(scope, reuse=reuse):
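        # Masked multi-head self attention over the entity axis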
        a = self_attention(inp, mask, heads, n_embd, layer_norm=layer_norm, qk_w=qk_w, v_w=v_w,
                           scope='self_attention', reuse=reuse)
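        # Project the attention output back to n_embd ("mlp1") and add the
        # first residual connection from the block input.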
        post_scale = np.sqrt(post_w / n_embd)
        post_a_mlp = tf.layers.dense(a,
                                     n_embd,
                                     kernel_initializer=tf.random_normal_initializer(stddev=post_scale),
                                     name="mlp1")
        x = inp + post_a_mlp
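        # Optionally layer-normalize along the feature axis after the residual add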
        if post_sa_layer_norm:
            with tf.variable_scope('post_a_layernorm'):
                x = tf.contrib.layers.layer_norm(x, begin_norm_axis=3)
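        # Optional mlp stack (up to two further dense layers) with its own residual connection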
        if n_mlp > 1:
            mlp = x
            mlp2_scale = np.sqrt(mlp_w1 / n_embd)
            mlp = tf.layers.dense(mlp,
                                  n_embd,
                                  kernel_initializer=tf.random_normal_initializer(stddev=mlp2_scale),
                                  name="mlp2")
        if n_mlp > 2:
            mlp3_scale = np.sqrt(mlp_w2 / n_embd)
            mlp = tf.layers.dense(mlp,
                                  n_embd,
                                  kernel_initializer=tf.random_normal_initializer(stddev=mlp3_scale),
                                  name="mlp3")
        if n_mlp > 1:
            x = x + mlp
        return x
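
Usage (a minimal sketch, assuming a TF1 graph-mode setup as implied by the tf.layers and
tf.contrib calls above; the shapes, head count, and tensor names are illustrative):

    import tensorflow as tf
    from ma_policy.layers import residual_sa_block

    # Batch and time dims left dynamic; the feature dim must equal n_embd.
    entities = tf.placeholder(tf.float32, (None, None, 8, 64), name="entities")
    entity_mask = tf.placeholder(tf.float32, (None, None, 8), name="entity_mask")

    out = residual_sa_block(entities, entity_mask, heads=4, n_embd=64,
                            layer_norm=True, post_sa_layer_norm=True, n_mlp=2)
    # out: (BS, T, NE, 64) -- same entity structure as the input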