# ma_policy/layers.py
import numpy as np
import tensorflow as tf
def residual_sa_block(inp, mask, heads, n_embd,
                      layer_norm=False, post_sa_layer_norm=False,
                      n_mlp=1, qk_w=0.125, v_w=0.125, post_w=0.125,
                      mlp_w1=0.125, mlp_w2=0.125,
                      scope="residual_sa_block", reuse=False):
    '''
        Residual self attention block for entities.
        Notation:
            BS - batch size
            T  - time
            NE - number of entities
            f  - feature dimension (must equal n_embd, since the block adds
                 inp to the n_embd-dimensional attention output)
        Args:
            inp (tf): (BS, T, NE, f)
            mask (tf): (BS, T, NE)
            heads (int) -- number of attention heads
            n_embd (int) -- embedding dimension; queries, keys, and values each
                have dimension n_embd / heads per head
            layer_norm (bool) -- normalize the embedding before computing
                queries, keys, and values
            post_sa_layer_norm (bool) -- layer-normalize the output of the
                self-attention residual branch
            n_mlp (int) -- number of mlp layers. If there is more than 1 mlp
                layer, we add a residual connection from after the first mlp
                to after the last mlp.
            qk_w, v_w, post_w, mlp_w1, mlp_w2 (float) -- scales for the Gaussian
                initializers of the query/key projection, the value projection,
                the post-attention mlp, the second mlp, and the third mlp,
                respectively. The std of each initializer is sqrt(scale / n_embd).
            scope (string) -- tf variable scope
            reuse (bool) -- tf reuse flag
    '''
    with tf.variable_scope(scope, reuse=reuse):
        # Self-attention over entities (self_attention is defined elsewhere in
        # this module), followed by a dense projection and a residual add.
        a = self_attention(inp, mask, heads, n_embd, layer_norm=layer_norm,
                           qk_w=qk_w, v_w=v_w,
                           scope='self_attention', reuse=reuse)
        post_scale = np.sqrt(post_w / n_embd)
        post_a_mlp = tf.layers.dense(a,
                                     n_embd,
                                     kernel_initializer=tf.random_normal_initializer(stddev=post_scale),
                                     name="mlp1")
        x = inp + post_a_mlp
        if post_sa_layer_norm:
            with tf.variable_scope('post_a_layernorm'):
                # Normalize over the feature axis (axis 3 of (BS, T, NE, f)).
                x = tf.contrib.layers.layer_norm(x, begin_norm_axis=3)
        if n_mlp > 1:
            mlp = x
            mlp2_scale = np.sqrt(mlp_w1 / n_embd)
            mlp = tf.layers.dense(mlp,
                                  n_embd,
                                  kernel_initializer=tf.random_normal_initializer(stddev=mlp2_scale),
                                  name="mlp2")
        if n_mlp > 2:
            mlp3_scale = np.sqrt(mlp_w2 / n_embd)
            mlp = tf.layers.dense(mlp,
                                  n_embd,
                                  kernel_initializer=tf.random_normal_initializer(stddev=mlp3_scale),
                                  name="mlp3")
        if n_mlp > 1:
            # Residual connection around the mlp stack.
            x = x + mlp
        return x
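

# Usage sketch -- an assumption for illustration, not part of the original file.
# It assumes TF 1.x graph mode and that self_attention is defined earlier in
# this module. Note that the feature dimension f of `inp` must equal n_embd,
# since the block adds `inp` to the n_embd-dimensional projection of the
# attention output. The shapes and hyperparameters below are hypothetical.
if __name__ == '__main__':
    entities = tf.placeholder(tf.float32, shape=(None, 8, 16, 64))  # (BS, T, NE, f)
    entity_mask = tf.placeholder(tf.float32, shape=(None, 8, 16))   # (BS, T, NE)
    out = residual_sa_block(entities, entity_mask, heads=4, n_embd=64,
                            post_sa_layer_norm=True, n_mlp=2)
    print(out.shape)  # (?, 8, 16, 64) -- same shape as the input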