in train.py [0:0]
def get_callback(attn_mode):
    """Build a per-block mask callback for the given attention mode.

    Args:
        attn_mode: Selects how entries inside a block are masked:
            'a_all', 'b_all', 'bT_all' -- nothing is masked;
            'a', 'bT', 'b0' -- on blocks where ``qry_idx == key_idx``, entries
                whose key index exceeds the query index are zeroed;
            'b' -- as above, and additionally every entry whose absolute query
                position minus absolute key position is >= ``H.local_attn_ctx``
                is zeroed.

    Returns:
        ``cb(blk_shape, head_idx, qry_idx, key_idx, blk_idx)``, which returns a
        boolean array of shape ``blk_shape`` where False marks masked-out
        entries. ``cb`` raises ValueError for a non-square block or an
        unrecognised ``attn_mode`` (the mode is only checked when ``cb`` runs).

    Note: mode 'b' reads ``H.local_attn_ctx`` and ``H.print_attn_layout``;
    ``H`` is defined outside this function.
    """
    def cb(blk_shape, head_idx, qry_idx, key_idx, blk_idx):
        # head_idx and blk_idx are part of the callback signature but unused.
        qdim, kdim = blk_shape
        if qdim != kdim:
            # The diagonal (query == key) masking below assumes square blocks;
            # raise explicitly so the check is not stripped under `python -O`.
            raise ValueError(
                'expected a square block, got shape %r' % (tuple(blk_shape),))
        # `np.bool` was deprecated in NumPy 1.20 and raises AttributeError in
        # 1.24-1.26; the builtin `bool` is the portable spelling of this dtype.
        mask = np.ones(blk_shape, dtype=bool)
        if attn_mode in ('a_all', 'b_all', 'bT_all'):
            return mask
        if qry_idx == key_idx:
            # Same query/key block: zero every entry above the main diagonal,
            # i.e. keys that come after their query within the block.
            mask = np.tril(mask)
        if attn_mode in ('a', 'bT', 'b0'):
            return mask
        if attn_mode == 'b':
            bandwidth = H.local_attn_ctx
            # Convert block-local offsets to absolute positions, then zero keys
            # that lie `bandwidth` or more positions before their query.
            q_abs = np.arange(qdim)[:, None] + qdim * qry_idx
            k_abs = np.arange(kdim)[None, :] + kdim * key_idx
            mask &= (q_abs - k_abs) < bandwidth
            if H.print_attn_layout:
                # Debug aid: dump the block as 0/1 rows, then stop in the
                # debugger. NOTE(review): this fires for every 'b' block.
                for row in mask:
                    print(' '.join(str(x) for x in row.astype(np.int32)))
                print(qry_idx, key_idx)
                pdb.set_trace()
            return mask
        raise ValueError('unknown attn_mode: %r' % (attn_mode,))
    return cb