def get_blocksparse_obj()

in train.py [0:0]


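# Assumed module-level context from train.py (not shown in this excerpt):
# numpy as np, pdb, the blocksparse module imported as bs (providing
# BlocksparseTransformer), the hyperparameter object H, and the per-mode
# mask callback get_callback.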
def get_blocksparse_obj(n_ctx, n_heads, attn_mode):
    '''a is dense attention, b is local attention (attend to the previous k positions),
    bT is strided attention (attend to every k-th position), implemented as a transpose'''
    key = f'{n_ctx}-{n_heads}-{attn_mode}'
    bst = H.bst_cache.get(key)
    if bst is not None:
        return bst

    blocksize = H.blocksize
    n_bctx = n_ctx // blocksize

    if attn_mode in ['b', 'bT', 'b0']:
        if attn_mode in ['b']:
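            # Local ('b') attention: the window covers H.local_attn_ctx previous
            # positions, i.e. this many block-diagonals below the main diagonal.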
            assert H.local_attn_ctx % blocksize == 0
            extra_diagonals = H.local_attn_ctx // blocksize
        elif attn_mode in ['bT', 'b0']:
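            # Strided ('bT'/'b0') attention: the every-k-th pattern itself is realized
            # via a transpose (see docstring); block_chunks is the chunk length, in
            # layout blocks, used in the loop below.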
            bT_ctx = H.attn_ctx // H.local_attn_ctx
            assert bT_ctx % blocksize == 0
            block_chunks = bT_ctx // blocksize
        layout = np.ones([n_bctx, n_bctx], dtype=bool)
        for q_idx in range(n_bctx):
            # Causal masking: a query block cannot attend to key blocks after it
            layout[q_idx, q_idx + 1:] = 0
            if attn_mode == 'b':
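                # Keep only a band of `extra_diagonals` block-diagonals below
                # (and including) the main diagonal.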
                start = max(0, q_idx - extra_diagonals)
                layout[q_idx, :start] = 0
            elif attn_mode in ['bT', 'b0']:
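                # Keep only the block columns from the start of the current chunk
                # of `block_chunks` blocks up to the diagonal.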
                offset = q_idx % block_chunks
                layout[q_idx, :q_idx - offset] = 0

    elif attn_mode == 'a':
        # standard causal attention
        layout = np.ones([n_bctx, n_bctx], dtype=bool)
        for q_idx in range(n_bctx):
            layout[q_idx, q_idx + 1:] = 0
    elif attn_mode == 'a_all':
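        # 'a_all': every query block may attend to every key block.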
        layout = np.ones([n_bctx, n_bctx], dtype=bool)
        if H.mem_block and H.block_memory:
            # Block attention over the memory block
            layout[:-1, -1] = 0
    elif attn_mode in ['b_all', 'bT_all']:
        assert H.blocksize == 32
        assert H.local_attn_ctx == 32
        assert n_bctx == 32
        layout = np.zeros([n_bctx, n_bctx], dtype=bool)
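        # Block-diagonal layout: each query block attends only to its own key block.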
        for q_idx in range(n_bctx):
            layout[q_idx, q_idx] = True
    else:
        raise NotImplementedError

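    # Optional debug aid: print the top-left corner of the block layout, then
    # drop into pdb so it can be inspected interactively.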
    if H.print_attn_layout:
        width = H.attn_cols_to_print
        for i in range(min(width, n_bctx)):
            print(' '.join([str(x) for x in layout[i, 0:width].astype(np.int32)]))
        pdb.set_trace()

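    # Construct and cache the blocksparse transformer object for this layout;
    # get_callback(attn_mode) supplies the mask applied within each active block.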
    bst = bs.BlocksparseTransformer(
        layout, block_size=blocksize,
        mask_callback=get_callback(attn_mode), heads=n_heads)

    H.bst_cache[key] = bst
    return bst
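
The function above is tied to the global H object and the blocksparse kernels, which makes the layouts hard to inspect in isolation. The sketch below is a minimal, standalone reproduction of just the block-layout construction for the 'a', 'b', and 'bT'/'b0' modes. build_layout and the small settings (n_ctx=32, blocksize=2, local_attn_ctx=8) are purely illustrative and not part of train.py; the strided branch also substitutes n_ctx where train.py uses H.attn_ctx.

import numpy as np

def build_layout(n_ctx, blocksize, local_attn_ctx, attn_mode):
    '''Standalone mirror of the block-layout construction above (illustration only).'''
    n_bctx = n_ctx // blocksize
    layout = np.ones([n_bctx, n_bctx], dtype=bool)
    for q_idx in range(n_bctx):
        # Causal mask: no key blocks after the query block
        layout[q_idx, q_idx + 1:] = 0
        if attn_mode == 'b':
            extra_diagonals = local_attn_ctx // blocksize
            layout[q_idx, :max(0, q_idx - extra_diagonals)] = 0
        elif attn_mode in ['bT', 'b0']:
            # train.py derives this from H.attn_ctx; n_ctx is used here for simplicity
            block_chunks = (n_ctx // local_attn_ctx) // blocksize
            layout[q_idx, :q_idx - q_idx % block_chunks] = 0
    return layout

for mode in ['a', 'b', 'bT']:
    print(mode)
    print(build_layout(n_ctx=32, blocksize=2, local_attn_ctx=8, attn_mode=mode).astype(int))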