def dense_attention()

in train.py


def dense_attention(x, n_heads, attn_mode, use_cache=False, train=False, pdrop=None):

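    # nx is the model width; queries and keys are projected down to
    # n_state = nx * H.qk_ratio channels, which must split evenly across heads.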
    nx = x.shape[-1].value
    n_state = int(nx * H.qk_ratio)
    if n_state % n_heads != 0:
        raise ValueError('n_state must be divisible by n_heads')

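    # Layer-normalize the block input before computing attention.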
    h = norm("attn_input", x)

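    # With the KV cache enabled, only the newest timestep needs a query;
    # keys and values are still computed over the full input h.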
    qh = h[:, -1:, :] if use_cache else h

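    # Linear projections: q and k use n_state channels, v keeps the full nx width.
    # Initialization std is scaled as sqrt(H.qk_w / nx) and sqrt(H.v_w / nx).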
    q = linear('q_proj', qh, n_state, std=np.sqrt(H.qk_w / nx))
    k = linear('k_proj', h, n_state, std=np.sqrt(H.qk_w / nx))
    v = linear('v_proj', h, nx, std=np.sqrt(H.v_w / nx))

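    # Split channels into attention heads: [batch, time, channels] -> [batch, heads, time, head_dim].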
    q = split_heads("q_split", q, n_heads)
    k = split_heads("k_split", k, n_heads)
    v = split_heads("v_split", v, n_heads)

    if use_cache:
        if attn_mode not in ['a', 'b', 'c', 'bT']:
            raise NotImplementedError
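        # During cached sampling the single query may attend to every cached position,
        # so no mask is applied; instead, k/v are sliced to match the attention pattern.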
        mask = None
        if attn_mode == 'b':
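            # Local attention: the query only sees the last local_attn_ctx positions.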
            k = k[:, :, -H.local_attn_ctx:, :]
            v = v[:, :, -H.local_attn_ctx:, :]
        elif attn_mode in ['c', 'bT']:
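            # Strided attention: keep every local_attn_ctx-th position (counted from the
            # end), then reverse so the kept positions are back in temporal order.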
            k = k[:, :, ::-H.local_attn_ctx, :][:, :, ::-1, :]
            v = v[:, :, ::-H.local_attn_ctx, :][:, :, ::-1, :]
    else:
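        # Full (non-cached) pass: build the dense mask for attn_mode over all n_timesteps positions.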
        n_timesteps = k.shape[2].value
        mask = get_dense_attn_mask(n_timesteps, attn_mode)
    if H.float16:
        # These products can overflow, so we do it in float32.
        k = bs.float_cast(k, dtype=tf.float32)
        q = bs.float_cast(q, dtype=tf.float32)
        v = bs.float_cast(v, dtype=tf.float32)
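    # Scaled dot-product attention: softmax(q k^T / sqrt(head_dim)) v.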
    w = tf.matmul(q, k, transpose_b=True)
    w = bs.masked_softmax(w, mask=mask, scale=1.0 / np.sqrt(q.shape[-1].value))
    a = tf.matmul(w, v)
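    # Merge heads back to [batch, time, channels] and cast back to fp16 if needed.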
    a = merge_heads("merge_attn", a)
    if H.float16:
        a = bs.float_cast(a, dtype=tf.float16)

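    # Output projection, dropout, and the residual connection with x are handled by post_attention.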
    return post_attention(x, a, use_cache=use_cache, train=train, pdrop=pdrop)
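
A minimal sketch of a call site, under the assumption that the H config exposes fields such as H.heads and H.attn_pdrop (hypothetical names; only H.qk_ratio, H.qk_w, H.v_w, H.local_attn_ctx, and H.float16 appear above):

    # Hypothetical usage inside a transformer block; x has shape [batch, n_ctx, n_embd].
    # attn_mode 'b' selects the local pattern whose window is H.local_attn_ctx.
    a = dense_attention(x, n_heads=H.heads, attn_mode='b',
                        use_cache=False, train=True, pdrop=H.attn_pdrop)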