in train.py [0:0]
def get_dense_attn_mask(n, attn_mode):
    '''Builds a [1, 1, n, n] float attention mask for the given mode:
    'a_all' is unmasked dense attention, 'a' is causal dense attention,
    'b' is local attention (each query sees the previous local_attn_ctx keys),
    and 'c' / 'bT' is strided attention (every local_attn_ctx-th earlier
    position), implemented as a transpose.'''
    key = f'{n}-{attn_mode}'
    dense_mask = H.dense_mask_cache.get(key)
    if dense_mask is not None:
        return dense_mask
    if attn_mode == 'a_all':
        # Fully dense: every position attends to every position, no causal mask.
        b = tf.ones([n, n], dtype=tf.float32)
    elif attn_mode == 'a':
        # Causal dense: lower-triangular mask.
        b = tf.matrix_band_part(tf.ones([n, n]), -1, 0)
    elif attn_mode == 'b':
        # Local causal: a window of `bandwidth` positions
        # (the query itself plus the previous bandwidth - 1 keys).
        bandwidth = H.local_attn_ctx
        ctx = tf.minimum(n - 1, bandwidth - 1)
        b = tf.matrix_band_part(tf.ones([n, n]), ctx, 0)
    elif attn_mode in ['c', 'bT']:
        # Strided causal: query q attends to key k when k <= q
        # and (q - k) is a multiple of `stride`.
        stride = H.local_attn_ctx
        x = tf.reshape(tf.range(n, dtype=tf.int32), [n, 1])
        y = tf.transpose(x)
        z = tf.zeros([n, n], dtype=tf.int32)
        q = z + x
        k = z + y
        c1 = q >= k
        c2 = tf.equal(tf.floormod(q - k, stride), 0)
        c3 = tf.logical_and(c1, c2)
        b = tf.cast(c3, tf.float32)
    else:
        raise ValueError('Not yet implemented')
    b = tf.reshape(b, [1, 1, n, n])
    H.dense_mask_cache[key] = b
    return b
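

# --- Illustrative sketch (not part of train.py) ---
# Rebuilds the strided ('c' / 'bT') pattern produced above in plain NumPy so it
# can be inspected without a TF session or the global H object. The helper name
# `reference_strided_mask` and the example values (n=6, stride=2) are arbitrary
# choices for illustration, not names used elsewhere in this codebase.
import numpy as np

def reference_strided_mask(n, stride):
    # Query q may attend to key k when k <= q and (q - k) is a multiple of stride.
    q = np.arange(n)[:, None]
    k = np.arange(n)[None, :]
    return ((q >= k) & ((q - k) % stride == 0)).astype(np.float32)

print(reference_strided_mask(6, 2))
# [[1. 0. 0. 0. 0. 0.]
#  [0. 1. 0. 0. 0. 0.]
#  [1. 0. 1. 0. 0. 0.]
#  [0. 1. 0. 1. 0. 0.]
#  [1. 0. 1. 0. 1. 0.]
#  [0. 1. 0. 1. 0. 1.]]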