in train.py [0:0]
def sparse_attention(x, n_heads, attn_mode, use_cache=False, train=False, pdrop=None):
    if use_cache:
        raise NotImplementedError
    if not H.float16:
        raise ValueError("sparse_attention requires fp16")
    nx = x.shape[-1].value
    n_state = int(nx * H.qk_ratio)
    if n_state % n_heads != 0:
        raise ValueError('n_state must be divisible by n_heads')
    h = norm("attn_input", x)
    if attn_mode in ['bT', 'bT_all']:
        # Block-transpose modes: reorder the sequence so that tokens a stride
        # of `ctx` apart become adjacent before the blocksparse kernels run.
        ctx = H.local_attn_ctx
        bT_ctx = H.attn_ctx // ctx
        assert bT_ctx % H.blocksize == 0, f'{bT_ctx}, {H.blocksize}'
        n, t, embd = shape_list(h)
        h = tf.reshape(h, [n, bT_ctx, ctx, embd])
        h = bs.transpose_0213(h)
        h = tf.reshape(h, [n, t, embd])
    # Queries/keys use the reduced width n_state (nx * qk_ratio); values keep
    # the full embedding width nx.
    q = linear('q_proj', h, n_state, std=np.sqrt(H.qk_w / nx))
    k = linear('k_proj', h, n_state, std=np.sqrt(H.qk_w / nx))
    v = linear('v_proj', h, nx, std=np.sqrt(H.v_w / nx))
    bst = get_blocksparse_obj(H.attn_ctx, n_heads, attn_mode)
    w = bst.query_key_op(q, k)
    # Masked softmax over the sparse pattern, scaled by 1/sqrt(per-head key dim).
    w = bst.masked_softmax(w, scale=1.0 / np.sqrt(n_state // n_heads))
    a = bst.weight_value_op(w, v)
    if attn_mode in ['bT', 'bT_all']:
        # Undo the block-transpose reordering applied above.
        a = tf.reshape(a, [n, ctx, bT_ctx, embd])
        a = bs.transpose_0213(a)
        a = tf.reshape(a, [n, t, embd])
    return post_attention(x, a, train=train, pdrop=pdrop)
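
Illustration (not part of train.py): a minimal numpy sketch of the token reordering performed by the 'bT' reshape/transpose above, assuming hypothetical sizes attn_ctx = 16 and local_attn_ctx = 4, and with np.transpose standing in for bs.transpose_0213 (both swap axes 1 and 2).

import numpy as np

attn_ctx, ctx = 16, 4                    # hypothetical sizes
bT_ctx = attn_ctx // ctx                 # 4 blocks of 4 tokens each
pos = np.arange(attn_ctx)                # original token positions 0..15
blocked = pos.reshape(bT_ctx, ctx)       # rows are contiguous chunks of ctx tokens
reordered = blocked.transpose(1, 0).reshape(-1)
print(reordered)                         # [ 0  4  8 12  1  5  9 13  2  6 10 14  3  7 11 15]

Tokens that were a stride of ctx apart end up adjacent, so attention over contiguous runs of the reordered sequence reaches positions spaced ctx apart in the original order; the matching reshape/transpose after weight_value_op restores the original ordering.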