def get_callback()

in train.py [0:0]


def get_callback(attn_mode):
    """Return a per-block mask callback for the given attention layout.

    The returned ``cb(blk_shape, head_idx, qry_idx, key_idx, blk_idx)`` builds
    a boolean mask of shape ``blk_shape`` for the block at query-block index
    ``qry_idx`` and key-block index ``key_idx``; ``True`` entries are kept.
    ``head_idx`` and ``blk_idx`` are accepted but unused. Presumably invoked
    by a block-sparse attention op once per non-zero layout block -- TODO
    confirm against the caller.

    Modes:
      * ``'a_all'``, ``'b_all'``, ``'bT_all'``: every block is fully dense.
      * ``'a'``, ``'bT'``, ``'b0'``: dense, except diagonal blocks
        (``qry_idx == key_idx``) are lower-triangular (causal).
      * ``'b'``: as above, and additionally a key is dropped when it lies
        ``H.local_attn_ctx`` or more absolute positions before the query.

    Args:
        attn_mode: Name of the attention layout (see above).

    Returns:
        The callback. Calling it raises ``ValueError`` for an unrecognised
        ``attn_mode``; validation happens lazily, at call time.
    """
    def cb(blk_shape, head_idx, qry_idx, key_idx, blk_idx):
        # ``dtype=bool`` instead of ``np.bool``: that alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        mask = np.ones(blk_shape, dtype=bool)
        qdim, kdim = blk_shape
        assert qdim == kdim
        if attn_mode in ('a_all', 'b_all', 'bT_all'):
            return mask
        if qry_idx == key_idx:
            # Diagonal block: drop keys that come after the query (causal).
            mask = np.tril(mask)
        if attn_mode in ('a', 'bT', 'b0'):
            return mask
        if attn_mode == 'b':
            bandwidth = H.local_attn_ctx
            # Convert in-block offsets to absolute positions and keep only
            # keys fewer than ``bandwidth`` positions behind the query.
            q_abs = np.arange(qdim)[:, None] + qdim * qry_idx
            k_abs = np.arange(kdim)[None, :] + kdim * key_idx
            mask &= (q_abs - k_abs) < bandwidth
            if H.print_attn_layout:
                # Debug aid: dump the block, then drop into the debugger.
                for row in mask.astype(np.int32):
                    print(' '.join(str(x) for x in row))
                print(qry_idx, key_idx)
                pdb.set_trace()
            return mask
        raise ValueError('Unknown attn_mode: %r' % (attn_mode,))
    return cb