def sparse_attention()

in train.py


def sparse_attention(x, n_heads, attn_mode, use_cache=False, train=False, pdrop=None):
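    """Multi-head blocksparse self-attention over x of shape [batch, H.attn_ctx, nx].

    Relies on helpers defined elsewhere in train.py and in the blocksparse
    package (H, norm, linear, get_blocksparse_obj, shape_list, post_attention,
    bs); fp16 activations are required.
    """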

    if use_cache:
        raise NotImplementedError
    if not H.float16:
        raise ValueError("sparse_attention requires fp16")

    nx = x.shape[-1].value
    n_state = int(nx * H.qk_ratio)
    if n_state % n_heads != 0:
        raise ValueError('n_state must be divisible by n_heads')

    h = norm("attn_input", x)

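    # Strided ('bT'/'bT_all') modes: regroup the sequence so that positions
    # spaced local_attn_ctx apart become contiguous before the blocksparse ops run.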
    if attn_mode in ['bT', 'bT_all']:
        ctx = H.local_attn_ctx
        bT_ctx = H.attn_ctx // ctx
        assert bT_ctx % H.blocksize == 0, f'{bT_ctx}, {H.blocksize}'
        n, t, embd = shape_list(h)
        h = tf.reshape(h, [n, bT_ctx, ctx, embd])
        h = bs.transpose_0213(h)
        h = tf.reshape(h, [n, t, embd])

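    # Query/key projections of width n_state, value projection of width nx,
    # each with a variance-scaled initializer.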
    q = linear('q_proj', h, n_state, std=np.sqrt(H.qk_w / nx))
    k = linear('k_proj', h, n_state, std=np.sqrt(H.qk_w / nx))
    v = linear('v_proj', h, nx, std=np.sqrt(H.v_w / nx))

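    # Blocksparse kernels for this context length, head count and sparsity pattern.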
    bst = get_blocksparse_obj(H.attn_ctx, n_heads, attn_mode)

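    # QK^T over the nonzero blocks, masked softmax scaled by 1/sqrt(d_head),
    # then the weighted sum with V.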
    w = bst.query_key_op(q, k)
    w = bst.masked_softmax(w, scale=1.0 / np.sqrt(n_state // n_heads))
    a = bst.weight_value_op(w, v)

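    # Invert the strided regrouping so the output matches the original token order.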
    if attn_mode in ['bT', 'bT_all']:
        a = tf.reshape(a, [n, ctx, bT_ctx, embd])
        a = bs.transpose_0213(a)
        a = tf.reshape(a, [n, t, embd])

    return post_attention(x, a, train=train, pdrop=pdrop)
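
A minimal NumPy sketch, not part of train.py, of the strided regrouping performed by the 'bT'/'bT_all' branches above. It assumes bs.transpose_0213 permutes axes as (0, 2, 1, 3) and checks that the post-attention reshape/transpose is the exact inverse of the pre-attention one; n, bT_ctx, ctx and embd are toy stand-ins for the real shapes.

import numpy as np

n, bT_ctx, ctx, embd = 1, 4, 2, 1          # toy sizes; t = bT_ctx * ctx
t = bT_ctx * ctx
x = np.arange(t, dtype=np.float32).reshape(n, t, embd)

def regroup(h):
    # pre-attention: tokens spaced `ctx` apart become contiguous
    h = h.reshape(n, bT_ctx, ctx, embd)
    h = h.transpose(0, 2, 1, 3)            # assumed bs.transpose_0213 behaviour
    return h.reshape(n, t, embd)

def ungroup(a):
    # post-attention: inverse permutation, restoring the original order
    a = a.reshape(n, ctx, bT_ctx, embd)
    a = a.transpose(0, 2, 1, 3)
    return a.reshape(n, t, embd)

y = regroup(x)
assert y[0, :, 0].tolist() == [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0]
assert np.array_equal(ungroup(y), x)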