def _ProjectAndSplitHeads()

in trax/layers/research/efficient_attention.py [0:0]


def _ProjectAndSplitHeads(  # pylint: disable=invalid-name
    d_model,
    n_heads,
    use_bias,
    num_weights=2,
    sparsity=16,
    length_kernel_size=3,
    weights_format='sparse',
    rotary_position_emb=False,
    mode='train'):
  """Creates the QK and V activations from input."""
  # There can be either two or three weights:
  # two - qk and v or three - q, k, v
  # If there are three, we want to average q and k and use that.

  # Weights can also be in 'heads' major format - (n_heads, d_model, d_head)
  # this is used by efficient_attention.LSHSelfAttention and
  # efficient_attention.SelfAttention

  # Or they can be in 'model' major format - (d_model, d_model), which is what
  # tl._attention/CausalAttention etc use -- so use this format if we pretrain a
  # model trained with those and finetuning with PureLSHSelfAttention.

  assert weights_format in ('heads', 'model', 'sparse')

  # When an earlier model was trained with 3 separate weights for Q, K, V
  # projections with tl._attention/tl._causalAttention etc.
  if weights_format == 'model' and num_weights == 3:
    return cb.Serial(
        # Create the raw Q, K, V projections.
        cb.Branch(
            core.Dense(d_model, use_bias=use_bias),
            core.Dense(d_model, use_bias=use_bias),
            core.Dense(d_model, use_bias=use_bias)),  # q, k, v
        # Optionally, rotate Q and K vectors if rotary embeddings are used.
        cb.Parallel(rotary_pe.Rotate(), rotary_pe.Rotate(), None)
        if rotary_position_emb else [],
        # Average Q and K into one single QK tensor.
        core.Fn('QKAvg', lambda x, y: (x + y) / 2.0, n_out=1),  # qk,   v
        # Split heads and combine with batch dimension to get two tensors of
        # (batch * n_heads, seq_len, d_head) shape.
        cb.Parallel(
            attention.SplitIntoHeads(n_heads),
            attention.SplitIntoHeads(n_heads))  # qk,   v
    )

  if weights_format == 'sparse' and num_weights == 3:
    d_module = d_model // sparsity
    # This layer matches sparsity.MultiplicativeConvCausalAttention,
    # see there for more explanation.
    # TODO(lukaszkaiser): unify code so that we don't duplicate so much.
    return cb.Serial(
        cb.Select([0, 0]),  # duplicate activations
        sp.FactoredDense(sparsity, d_model, d_model),
        cb.Select([0, 0, 0]),  # use for q, k, v
        cb.Parallel(
            [sp.LocallyConvDense(sparsity, d_module, mode=mode,
                                 kernel_size=3,
                                 length_kernel_size=length_kernel_size),
             attention.SplitIntoHeads(n_heads)],
            [sp.LocallyConvDense(sparsity, d_module, mode=mode,
                                 kernel_size=3,
                                 length_kernel_size=length_kernel_size),
             attention.SplitIntoHeads(n_heads)],
            [cb.Select([0], n_in=2),
             sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,
                                 length_kernel_size=length_kernel_size),
             attention.SplitIntoHeads(n_heads)],
        ),
        core.Fn('QKAvg', lambda x, y: (x + y) / 2.0, n_out=1),
    )

  if weights_format == 'sparse' and num_weights == 2:
    d_module = d_model // sparsity
    # This layer matches sparsity.MultiplicativeConvCausalAttention,
    # see there for more explanation.
    # TODO(lukaszkaiser): unify code so that we don't duplicate so much.
    return cb.Serial(
        cb.Select([0, 0]),  # pre-qkv, pre-v-for-concat
        sp.FactoredDense(sparsity, d_model, d_model),  # shared q k
        cb.Select([0, 0]),  # pre-qk, pre-v, pre-v-for-concat
        sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,
                            length_kernel_size=length_kernel_size),
        attention.SplitIntoHeads(n_heads),
        cb.Parallel(
            [],
            [cb.Select([0], n_in=2),
             sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,
                                 length_kernel_size=length_kernel_size),
             attention.SplitIntoHeads(n_heads)],
        )
    )

  # We want to train from scratch and have only two weights, w_qk and w_v.
  if weights_format == 'model' and num_weights == 2:
    return cb.Branch(
        [
            core.Dense(d_model, use_bias=use_bias),
            rotary_pe.Rotate() if rotary_position_emb else [],
            attention.SplitIntoHeads(n_heads)
        ],
        [
            core.Dense(d_model, use_bias=use_bias),
            attention.SplitIntoHeads(n_heads)
        ],
    )

  assert weights_format == 'head'

  raise NotImplementedError('TODO(afrozm): Implement this when we want to use '
                            'checkpoints trained with LSHSelfAttention or '
                            'SelfAttention')