in ma_policy/layers.py [0:0]
def qkv_embed(inp, heads, n_embd, layer_norm=False, qk_w=1.0, v_w=0.01, reuse=False):
    '''
    Project per-entity embeddings into multi-head queries, keys, and values.
    Args:
        inp (tf) -- tensor w/ shape (bs, T, NE, features)
        heads (int) -- number of attention heads
        n_embd (int) -- dimension of queries, keys, and values will be n_embd / heads
        layer_norm (bool) -- normalize embedding prior to computing qkv
        qk_w (float) -- Initialization scale for keys and queries. Actual scale will be
            sqrt(qk_w / #input features)
        v_w (float) -- Initialization scale for values. Actual scale will be sqrt(v_w / #input features)
        reuse (bool) -- tf reuse
    '''
    with tf.variable_scope('qkv_embed'):
        bs, T, NE, features = shape_list(inp)
        if layer_norm:
            with tf.variable_scope('pre_sa_layer_norm'):
                inp = tf.contrib.layers.layer_norm(inp, begin_norm_axis=3)

        # Queries and keys come from a single shared dense projection; the
        # initializer stddev is scaled by fan-in (sqrt(qk_w / features)).
        qk_scale = np.sqrt(qk_w / features)
        qk = tf.layers.dense(
            inp,
            n_embd * 2,
            kernel_initializer=tf.random_normal_initializer(stddev=qk_scale),
            reuse=reuse,
            name="qk_embed")  # (bs, T, NE, n_embd * 2)
        # Expose the head axis, then peel queries and keys off the last axis.
        qk = tf.reshape(qk, (bs, T, NE, heads, n_embd // heads, 2))
        query, key = [tf.squeeze(half, -1) for half in tf.split(qk, 2, -1)]

        # Values get their own projection with an independent (typically
        # smaller) initialization scale.
        v_scale = np.sqrt(v_w / features)
        value = tf.layers.dense(
            inp,
            n_embd,
            kernel_initializer=tf.random_normal_initializer(stddev=v_scale),
            reuse=reuse,
            name="v_embed")  # (bs, T, NE, n_embd)
        value = tf.reshape(value, (bs, T, NE, heads, n_embd // heads))

        # Bring the head axis ahead of the entity axis; keys are additionally
        # transposed so a later matmul(query, key) contracts the feature axis.
        query = tf.transpose(query, (0, 1, 3, 2, 4),
                             name="transpose_query")  # (bs, T, heads, NE, n_embd / heads)
        key = tf.transpose(key, (0, 1, 3, 4, 2),
                           name="transpose_key")  # (bs, T, heads, n_embd / heads, NE)
        value = tf.transpose(value, (0, 1, 3, 2, 4),
                             name="transpose_value")  # (bs, T, heads, NE, n_embd / heads)
    return query, key, value