# trax/layers/research/efficient_attention.py
def _ProjectAndSplitHeads( # pylint: disable=invalid-name
d_model,
n_heads,
use_bias,
num_weights=2,
sparsity=16,
length_kernel_size=3,
weights_format='sparse',
rotary_position_emb=False,
mode='train'):
"""Creates the QK and V activations from input."""
  # There can be either two or three weights:
  # two - qk and v, or three - q, k and v.
  # If there are three, we average q and k and use the result as qk.
  # Weights can also be in 'heads'-major format - (n_heads, d_model, d_head);
  # this is what efficient_attention.LSHSelfAttention and
  # efficient_attention.SelfAttention use.
  # Or they can be in 'model'-major format - (d_model, d_model), which is what
  # tl.Attention/tl.CausalAttention etc. use -- so use this format when we
  # pretrain a model with those layers and fine-tune it with
  # PureLSHSelfAttention.
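  # For example, with d_model=512 and n_heads=8 (so d_head=64), 'heads'-major
  # weights have shape (8, 512, 64), while 'model'-major weights have shape
  # (512, 512).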
assert weights_format in ('heads', 'model', 'sparse')
  # When an earlier model was trained with 3 separate weights for the Q, K, V
  # projections, e.g. with tl.Attention/tl.CausalAttention etc.
if weights_format == 'model' and num_weights == 3:
return cb.Serial(
# Create the raw Q, K, V projections.
cb.Branch(
core.Dense(d_model, use_bias=use_bias),
core.Dense(d_model, use_bias=use_bias),
core.Dense(d_model, use_bias=use_bias)), # q, k, v
# Optionally, rotate Q and K vectors if rotary embeddings are used.
cb.Parallel(rotary_pe.Rotate(), rotary_pe.Rotate(), None)
if rotary_position_emb else [],
# Average Q and K into one single QK tensor.
core.Fn('QKAvg', lambda x, y: (x + y) / 2.0, n_out=1), # qk, v
# Split heads and combine with batch dimension to get two tensors of
# (batch * n_heads, seq_len, d_head) shape.
cb.Parallel(
attention.SplitIntoHeads(n_heads),
attention.SplitIntoHeads(n_heads)) # qk, v
)
if weights_format == 'sparse' and num_weights == 3:
d_module = d_model // sparsity
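    # Each of the `sparsity` modules in LocallyConvDense below outputs
    # d_module features, so the concatenated output is d_model-wide again.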
# This layer matches sparsity.MultiplicativeConvCausalAttention,
# see there for more explanation.
# TODO(lukaszkaiser): unify code so that we don't duplicate so much.
return cb.Serial(
cb.Select([0, 0]), # duplicate activations
sp.FactoredDense(sparsity, d_model, d_model),
cb.Select([0, 0, 0]), # use for q, k, v
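        # Select([0, 0, 0]) puts three copies of the projected activations on
        # top of the stack; the raw duplicate made by Select([0, 0]) above
        # stays underneath and is dropped inside the V branch below.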
cb.Parallel(
[sp.LocallyConvDense(sparsity, d_module, mode=mode,
kernel_size=3,
length_kernel_size=length_kernel_size),
attention.SplitIntoHeads(n_heads)],
[sp.LocallyConvDense(sparsity, d_module, mode=mode,
kernel_size=3,
length_kernel_size=length_kernel_size),
attention.SplitIntoHeads(n_heads)],
[cb.Select([0], n_in=2),
sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,
length_kernel_size=length_kernel_size),
attention.SplitIntoHeads(n_heads)],
),
        core.Fn('QKAvg', lambda x, y: (x + y) / 2.0, n_out=1),  # qk, v
)
if weights_format == 'sparse' and num_weights == 2:
d_module = d_model // sparsity
# This layer matches sparsity.MultiplicativeConvCausalAttention,
# see there for more explanation.
# TODO(lukaszkaiser): unify code so that we don't duplicate so much.
return cb.Serial(
cb.Select([0, 0]), # pre-qkv, pre-v-for-concat
sp.FactoredDense(sparsity, d_model, d_model), # shared q k
cb.Select([0, 0]), # pre-qk, pre-v, pre-v-for-concat
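        # The next two layers transform only the top of the stack (pre-qk);
        # the two copies underneath flow unchanged into the Parallel below.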
sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,
length_kernel_size=length_kernel_size),
attention.SplitIntoHeads(n_heads),
cb.Parallel(
[],
[cb.Select([0], n_in=2),
sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,
length_kernel_size=length_kernel_size),
attention.SplitIntoHeads(n_heads)],
)
)
# We want to train from scratch and have only two weights, w_qk and w_v.
if weights_format == 'model' and num_weights == 2:
return cb.Branch(
[
core.Dense(d_model, use_bias=use_bias),
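            # Rotary embeddings, if enabled, rotate only the shared QK
            # projection; V stays unrotated.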
rotary_pe.Rotate() if rotary_position_emb else [],
attention.SplitIntoHeads(n_heads)
],
[
core.Dense(d_model, use_bias=use_bias),
attention.SplitIntoHeads(n_heads)
],
)
  assert weights_format == 'heads'
raise NotImplementedError('TODO(afrozm): Implement this when we want to use '
'checkpoints trained with LSHSelfAttention or '
'SelfAttention')
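
# A minimal usage sketch (illustration only, not part of the original file;
# the 512/8 sizes are assumed for the example):
#
#   qkv = _ProjectAndSplitHeads(d_model=512, n_heads=8, use_bias=False,
#                               num_weights=2, weights_format='model')
#
# Given a [batch, length, 512] input, `qkv` emits a QK tensor and a V tensor,
# each of shape [batch * 8, length, 64].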