def get_dense_attn_mask()

in train.py [0:0]


def get_dense_attn_mask(n, attn_mode):
    '''Build an [n, n] attention mask for the given mode: 'a_all' is full
    (unmasked) attention, 'a' is dense causal attention, 'b' is local
    attention (previous local_attn_ctx positions), and 'c'/'bT' is strided
    attention (every local_attn_ctx-th previous position), implemented as a transpose.'''
    # Masks depend only on (sequence length, mode), so cache and reuse them.
    key = f'{n}-{attn_mode}'
    dense_mask = H.dense_mask_cache.get(key)
    if dense_mask is not None:
        return dense_mask

    if attn_mode == 'a_all':
        # Full attention: every position attends to every other position.
        b = tf.ones([n, n], dtype=tf.float32)
    elif attn_mode == 'a':
        # Dense causal attention: lower-triangular mask.
        b = tf.matrix_band_part(tf.ones([n, n]), -1, 0)
    elif attn_mode == 'b':
        # Local attention: each position attends to itself and the
        # previous (local_attn_ctx - 1) positions.
        bandwidth = H.local_attn_ctx
        ctx = tf.minimum(n - 1, bandwidth - 1)
        b = tf.matrix_band_part(tf.ones([n, n]), ctx, 0)
    elif attn_mode in ['c', 'bT']:
        # Strided attention: query q attends to keys k <= q whose offset
        # (q - k) is a multiple of the stride.
        stride = H.local_attn_ctx
        x = tf.reshape(tf.range(n, dtype=tf.int32), [n, 1])
        y = tf.transpose(x)
        z = tf.zeros([n, n], dtype=tf.int32)
        q = z + x  # q[i, j] = i (query position)
        k = z + y  # k[i, j] = j (key position)
        c1 = q >= k
        c2 = tf.equal(tf.floormod(q - k, stride), 0)
        c3 = tf.logical_and(c1, c2)
        b = tf.cast(c3, tf.float32)
    else:
        raise ValueError(f'attn_mode {attn_mode} not yet implemented')
    # Add leading batch and head dimensions so the mask broadcasts
    # against [batch, heads, n, n] attention logits.
    b = tf.reshape(b, [1, 1, n, n])

    H.dense_mask_cache[key] = b

    return b
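
To see the patterns these branches produce, the same masks can be re-derived in plain NumPy. This is a minimal sketch for inspection only: it is independent of the TensorFlow graph and of the H hyperparameter object, and `dense_attn_mask_np` with its `local_attn_ctx` argument are illustrative names, not part of train.py.

```python
import numpy as np

def dense_attn_mask_np(n, attn_mode, local_attn_ctx=None):
    # 2-D grids of query indices (rows) and key indices (columns).
    q = np.arange(n)[:, None]
    k = np.arange(n)[None, :]
    if attn_mode == 'a_all':
        # Full attention: no masking at all.
        mask = np.ones((n, n), dtype=bool)
    elif attn_mode == 'a':
        # Dense causal attention: key index must not exceed query index.
        mask = q >= k
    elif attn_mode == 'b':
        # Local attention: self plus the previous (local_attn_ctx - 1) positions.
        ctx = min(n - 1, local_attn_ctx - 1)
        mask = (q >= k) & (q - k <= ctx)
    elif attn_mode in ('c', 'bT'):
        # Strided attention: previous positions at multiples of the stride.
        mask = (q >= k) & ((q - k) % local_attn_ctx == 0)
    else:
        raise ValueError(attn_mode)
    return mask.astype(np.float32)

print(dense_attn_mask_np(6, 'bT', local_attn_ctx=2))
# [[1. 0. 0. 0. 0. 0.]
#  [0. 1. 0. 0. 0. 0.]
#  [1. 0. 1. 0. 0. 0.]
#  [0. 1. 0. 1. 0. 0.]
#  [1. 0. 1. 0. 1. 0.]
#  [0. 1. 0. 1. 0. 1.]]
```

For stride 2, each position attends only to earlier positions of the same parity, which is the strided pattern the 'bT' branch encodes. The TensorFlow version above additionally reshapes the result to [1, 1, n, n] so it broadcasts over the batch and head dimensions of the attention logits, and caches it per (n, attn_mode) pair.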