in train.py [0:0]
def get_blocksparse_obj(n_ctx, n_heads, attn_mode):
    '''a is dense attention, b is local attention (previous k),
    bT is strided (every kth element), implemented as a transpose'''
    key = f'{n_ctx}-{n_heads}-{attn_mode}'
    bst = H.bst_cache.get(key)
    if bst is not None:
        return bst
    blocksize = H.blocksize
    n_bctx = n_ctx // blocksize
    if attn_mode in ['b', 'bT', 'b0']:
        if attn_mode in ['b']:
            # Local attention: keep this many block diagonals below the main one.
            assert H.local_attn_ctx % blocksize == 0
            extra_diagonals = H.local_attn_ctx // blocksize
        elif attn_mode in ['bT', 'b0']:
            # Strided attention: the context splits into chunks of bT_ctx
            # positions, i.e. block_chunks block rows per chunk.
            bT_ctx = H.attn_ctx // H.local_attn_ctx
            assert bT_ctx % blocksize == 0
            block_chunks = bT_ctx // blocksize
        layout = np.ones([n_bctx, n_bctx], dtype=bool)
        for q_idx in range(n_bctx):
            # Causal queries cannot attend to keys above them
            layout[q_idx, q_idx + 1:] = 0
            if attn_mode == 'b':
                # Keep only the most recent extra_diagonals block diagonals.
                start = max(0, q_idx - extra_diagonals)
                layout[q_idx, :start] = 0
            elif attn_mode in ['bT', 'b0']:
                # Keep key blocks back to the start of the current chunk.
                offset = q_idx % block_chunks
                layout[q_idx, :q_idx - offset] = 0
    elif attn_mode == 'a':
        # Standard causal attention: full lower-triangular block layout.
        layout = np.ones([n_bctx, n_bctx], dtype=bool)
        for q_idx in range(n_bctx):
            layout[q_idx, q_idx + 1:] = 0
    elif attn_mode == 'a_all':
        # Fully dense block layout.
        layout = np.ones([n_bctx, n_bctx], dtype=bool)
        if H.mem_block and H.block_memory:
            # Block attention over the memory block
            layout[:-1, -1] = 0
    elif attn_mode in ['b_all', 'bT_all']:
        assert H.blocksize == 32
        assert H.local_attn_ctx == 32
        assert n_bctx == 32
        # Keep only the diagonal blocks; masking within each block is left to
        # the mask callback passed to BlocksparseTransformer below.
        layout = np.zeros([n_bctx, n_bctx], dtype=bool)
        for q_idx in range(n_bctx):
            layout[q_idx, q_idx] = True
    else:
        raise NotImplementedError
    if H.print_attn_layout:
        # Optionally dump the top-left corner of the block layout for inspection.
        width = H.attn_cols_to_print
        for i in range(min(width, n_bctx)):
            print(' '.join([str(x) for x in layout[i, 0:width].astype(np.int32)]))
        pdb.set_trace()
    bst = bs.BlocksparseTransformer(
        layout, block_size=blocksize,
        mask_callback=get_callback(attn_mode), heads=n_heads)
    H.bst_cache[key] = bst
    return bst
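
For reference, the block-level patterns built above can be reproduced without the blocksparse dependency. The sketch below is illustrative only (the helper name `toy_layout` and the toy sizes are made up, not part of train.py); it rebuilds the 'a', 'b', and 'bT' layouts on an 8x8 block grid and prints them as 0/1 matrices.

import numpy as np

def toy_layout(n_bctx, attn_mode, extra_diagonals=None, block_chunks=None):
    '''Toy re-implementation of the layout loop above, for visualization only.'''
    layout = np.ones([n_bctx, n_bctx], dtype=bool)
    for q_idx in range(n_bctx):
        layout[q_idx, q_idx + 1:] = 0           # causal: no future key blocks
        if attn_mode == 'b':
            start = max(0, q_idx - extra_diagonals)
            layout[q_idx, :start] = 0           # keep the last few block diagonals
        elif attn_mode == 'bT':
            offset = q_idx % block_chunks
            layout[q_idx, :q_idx - offset] = 0  # keep blocks back to the chunk start
    return layout.astype(np.int32)

# 8x8 block grid; local attention spanning 2 block diagonals;
# strided attention with chunks of 4 blocks.
for name, kwargs in [('a', {}), ('b', {'extra_diagonals': 2}), ('bT', {'block_chunks': 4})]:
    print(name)
    print(toy_layout(8, name, **kwargs))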