in train.py [0:0]
def dense_attention(x, n_heads, attn_mode, use_cache=False, train=False, pdrop=None):
    nx = x.shape[-1].value
    n_state = int(nx * H.qk_ratio)
    if n_state % n_heads != 0:
        raise ValueError('n_state must be divisible by n_heads')
    h = norm("attn_input", x)
    # When decoding with a cache, only the newest timestep needs a query.
    qh = h[:, -1:, :] if use_cache else h
    q = linear('q_proj', qh, n_state, std=np.sqrt(H.qk_w / nx))
    k = linear('k_proj', h, n_state, std=np.sqrt(H.qk_w / nx))
    v = linear('v_proj', h, nx, std=np.sqrt(H.v_w / nx))
    q = split_heads("q_split", q, n_heads)
    k = split_heads("k_split", k, n_heads)
    v = split_heads("v_split", v, n_heads)
    if use_cache:
        if attn_mode not in ['a', 'b', 'c', 'bT']:
            raise NotImplementedError
        # With a single query, masking is replaced by slicing the cached keys/values.
        mask = None
        if attn_mode == 'b':
            # Local attention: keep only the last local_attn_ctx timesteps.
            k = k[:, :, -H.local_attn_ctx:, :]
            v = v[:, :, -H.local_attn_ctx:, :]
        elif attn_mode in ['c', 'bT']:
            # Strided attention: every local_attn_ctx-th timestep, counted back from
            # the newest, then flipped back into chronological order.
            k = k[:, :, ::-H.local_attn_ctx, :][:, :, ::-1, :]
            v = v[:, :, ::-H.local_attn_ctx, :][:, :, ::-1, :]
    else:
        n_timesteps = k.shape[2].value
        mask = get_dense_attn_mask(n_timesteps, attn_mode)
    if H.float16:
        # These products can overflow, so we do them in float32.
        k = bs.float_cast(k, dtype=tf.float32)
        q = bs.float_cast(q, dtype=tf.float32)
        v = bs.float_cast(v, dtype=tf.float32)
    w = tf.matmul(q, k, transpose_b=True)
    w = bs.masked_softmax(w, mask=mask, scale=1.0 / np.sqrt(q.shape[-1].value))
    a = tf.matmul(w, v)
    a = merge_heads("merge_attn", a)
    if H.float16:
        a = bs.float_cast(a, dtype=tf.float16)
    return post_attention(x, a, use_cache=use_cache, train=train, pdrop=pdrop)
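
The strided ('c' / 'bT') cache branch is easy to misread. Below is a minimal NumPy sketch (illustration only, not part of train.py) of the index pattern the double slice selects, using a cache of 10 timesteps and a stride of 3 standing in for H.local_attn_ctx:

# Illustration only: the index pattern selected by
# k[:, :, ::-stride, :][:, :, ::-1, :] in the strided cache path.
import numpy as np

n_timesteps, stride = 10, 3          # stand-ins for the cached length and H.local_attn_ctx
positions = np.arange(n_timesteps)   # stands in for the time axis of k / v

stepped = positions[::-stride]       # every stride-th position, counting back from the newest
chronological = stepped[::-1]        # restore ascending time order

print(stepped)         # [9 6 3 0]
print(chronological)   # [0 3 6 9]

So the single cached query (the newest timestep) attends to the strided subset of past positions, itself included, without needing an explicit mask.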
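
In the non-cache path, the mask built by get_dense_attn_mask is applied inside the fused bs.masked_softmax, which also folds in the 1/sqrt(head_dim) scale. As a rough reference for what that step computes, here is a hedged NumPy sketch of scaled, masked attention, assuming mode 'a' corresponds to a standard lower-triangular causal mask; causal_attention_reference is a hypothetical name for this sketch, not a helper from this repo:

# Rough NumPy reference (assumption: mode 'a' is a plain causal mask) for the
# masked-softmax attention; the real computation uses the fused bs.masked_softmax op.
import numpy as np

def causal_attention_reference(q, k, v):
    # q, k: [batch, heads, n_timesteps, head_dim]; v: [batch, heads, n_timesteps, v_dim]
    n_timesteps = q.shape[2]
    scale = 1.0 / np.sqrt(q.shape[-1])
    w = np.einsum('bhqd,bhkd->bhqk', q, k) * scale       # scaled dot-product scores
    mask = np.tril(np.ones((n_timesteps, n_timesteps)))  # 1 where attention is allowed
    w = np.where(mask == 1, w, -1e9)                     # block attention to future positions
    w = np.exp(w - w.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)                   # softmax over key positions
    return np.einsum('bhqk,bhkv->bhqv', w, v)            # weighted sum of values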