def qkv_embed()

in ma_policy/layers.py [0:0]


def qkv_embed(inp, heads, n_embd, layer_norm=False, qk_w=1.0, v_w=0.01, reuse=False):
    '''
        Compute queries, keys, and values for multi-head attention.
        Args:
            inp (tf) -- tensor w/ shape (bs, T, NE, features)
            heads (int) -- number of attention heads; must evenly divide n_embd
            n_embd (int) -- total embedding size; each head's queries, keys, and values
                have dimension n_embd / heads
            layer_norm (bool) -- normalize embedding (over the features axis) prior to
                computing qkv
            qk_w (float) -- Initialization scale for keys and queries. Actual stddev will be
                sqrt(qk_w / #input features)
            v_w (float) -- Initialization scale for values. Actual stddev will be
                sqrt(v_w / #input features)
            reuse (bool) -- tf reuse; applied to every variable created here, including
                the optional layer norm
        Returns:
            query (tf) -- shape (bs, T, heads, NE, n_embd / heads)
            key (tf) -- shape (bs, T, heads, n_embd / heads, NE); already transposed so
                that tf.matmul(query, key) has shape (bs, T, heads, NE, NE)
            value (tf) -- shape (bs, T, heads, NE, n_embd / heads)
        Raises:
            ValueError -- if n_embd is not divisible by heads
    '''
    # Fail early with a clear message; otherwise the reshapes below would raise an
    # opaque shape-mismatch error because heads * (n_embd // heads) != n_embd.
    if n_embd % heads != 0:
        raise ValueError(
            "n_embd ({}) must be divisible by heads ({})".format(n_embd, heads))
    head_dim = n_embd // heads

    # Passing reuse to the enclosing scope makes it cover the layer-norm variables as
    # well as the dense layers. In TF1, reuse=False is treated as None (inherit the
    # parent scope's flag), so the default behaviour is unchanged.
    with tf.variable_scope('qkv_embed', reuse=reuse):
        bs, T, NE, features = shape_list(inp)
        if layer_norm:
            with tf.variable_scope('pre_sa_layer_norm'):
                # inp is rank 4 (see unpacking above), so axis 3 is the features axis.
                inp = tf.contrib.layers.layer_norm(inp, begin_norm_axis=3)

        # Queries and keys come from one fused projection of width 2 * n_embd.
        qk_scale = np.sqrt(qk_w / features)
        qk = tf.layers.dense(inp,
                             n_embd * 2,
                             kernel_initializer=tf.random_normal_initializer(stddev=qk_scale),
                             reuse=reuse,
                             name="qk_embed")  # (bs, T, NE, n_embd * 2)
        qk = tf.reshape(qk, (bs, T, NE, heads, head_dim, 2))

        # The trailing axis of size 2 separates query from key.
        # Each result has shape (bs, T, NE, heads, n_embd / heads).
        query, key = [tf.squeeze(x, -1) for x in tf.split(qk, 2, -1)]

        v_scale = np.sqrt(v_w / features)
        value = tf.layers.dense(inp,
                                n_embd,
                                kernel_initializer=tf.random_normal_initializer(stddev=v_scale),
                                reuse=reuse,
                                name="v_embed")  # (bs, T, NE, n_embd)
        value = tf.reshape(value, (bs, T, NE, heads, head_dim))

        # Move the heads axis ahead of the entity axis. Key is additionally transposed
        # so that its last two axes are (n_embd / heads, NE), ready for matmul.
        query = tf.transpose(query, (0, 1, 3, 2, 4),
                             name="transpose_query")  # (bs, T, heads, NE, n_embd / heads)
        key = tf.transpose(key, (0, 1, 3, 4, 2),
                           name="transpose_key")  # (bs, T, heads, n_embd / heads, NE)
        value = tf.transpose(value, (0, 1, 3, 2, 4),
                             name="transpose_value")  # (bs, T, heads, NE, n_embd / heads)

    return query, key, value