def __init__()

in TransformerEncoderDecoder, lingvo/jax/layers/transformers.py
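
Builds the sub-layers of a TransformerEncoderDecoder: optional position
embeddings, the encoder and decoder stacks (StackedTransformer or
StackedTransformerRepeated), optional per-side input-embedding lookups and
NGrammer layers, a LayerNorm on each stack's output, and the final softmax.
model_dims and packed_input from the parent params are propagated into the
stack templates.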


  def __init__(self, params):
    # This will create an encoder-decoder model; its sub-layers are built below.
    super().__init__(params)
    p = self.params

    def set_model_dims_and_packing(stacked_transformer_tpl, model_dims,
                                   packed_input):
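      """Propagates model_dims and packed_input into a stack template.

      Handles both StackedTransformer and StackedTransformerRepeated, whose
      per-block settings live on `block`; any other class fails the assert.
      """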
      if stacked_transformer_tpl.cls == StackedTransformer:
        assert (stacked_transformer_tpl.model_dims == 0 or
                stacked_transformer_tpl.model_dims == model_dims)
        stacked_transformer_tpl.model_dims = model_dims
        stacked_transformer_tpl.packed_input = packed_input
      elif stacked_transformer_tpl.cls == StackedTransformerRepeated:
        assert (stacked_transformer_tpl.block.model_dims == 0 or
                stacked_transformer_tpl.block.model_dims == model_dims)
        stacked_transformer_tpl.block.model_dims = model_dims
        stacked_transformer_tpl.block.packed_input = packed_input
      else:
        assert False, f'{stacked_transformer_tpl.cls} not supported.'

    # Create position embeddings.
    if p.position_emb_tpl is not None:
      assert (p.position_emb_tpl.embedding_dims == 0 or
              p.position_emb_tpl.embedding_dims == p.model_dims)
      p.position_emb_tpl.embedding_dims = p.model_dims
      self.create_child('position_emb', p.position_emb_tpl)

    # Create the encoder.
    if p.encoder_stacked_transformer_tpl is None:
      raise ValueError(
          'Encoder stack must be specified for TransformerEncoderDecoder.')

    # Use the user-specified StackedTransformer template for the encoder,
    # assuming everything is set up appropriately.
    encoder_params = p.encoder_stacked_transformer_tpl.Copy()
    set_model_dims_and_packing(encoder_params, p.model_dims, p.packed_input)
    # Check that the encoder self-attention is not masked.
    if encoder_params.cls == StackedTransformer:
      mask_self_attention = encoder_params.mask_self_attention
    elif encoder_params.cls == StackedTransformerRepeated:
      mask_self_attention = encoder_params.block.mask_self_attention
    else:
      raise ValueError('Unknown encoder stack.')

    if mask_self_attention:
      raise ValueError(
          'Encoder attention should be un-masked in TransformerEncoderDecoder.')
    self.create_child('encoder', encoder_params)

    # Optional separate embedding layer for source ids.
    if p.encoder_embedding_tpl is not None:
      encoder_embedding_params = p.encoder_embedding_tpl.Copy()
      assert (encoder_embedding_params.embedding_dims == 0 or
              encoder_embedding_params.embedding_dims == p.model_dims)
      encoder_embedding_params.embedding_dims = p.model_dims
      self.create_child('encoder_embedding_lookup', encoder_embedding_params)

    # Optional NGrammer layer for the encoder.
    # Paper: https://openreview.net/forum?id=GxjCYmQAody
    if p.encoder_ngrammer_tpl is not None:
      self.create_child('encoder_ngrammer', p.encoder_ngrammer_tpl)

    # Encoder output layer norm.
    encoder_ln_params = normalizations.LayerNorm.Params().Set(
        input_dims=p.model_dims)
    self.create_child('encoder_ln', encoder_ln_params)

    # Create the decoder.
    if p.decoder_stacked_transformer_tpl is None:
      raise ValueError(
          'Decoder stack must be specified for TransformerEncoderDecoder.')

    # Use the user-specified StackedTransformer template for the decoder,
    # assuming everything is set up appropriately.
    decoder_params = p.decoder_stacked_transformer_tpl.Copy()
    set_model_dims_and_packing(decoder_params, p.model_dims, p.packed_input)
    # Check that the decoder self-attention is masked.
    if decoder_params.cls == StackedTransformer:
      mask_self_attention = decoder_params.mask_self_attention
    elif decoder_params.cls == StackedTransformerRepeated:
      mask_self_attention = decoder_params.block.mask_self_attention
    else:
      raise ValueError('Unknown decoder stack.')

    if not mask_self_attention:
      raise ValueError(
          'Decoder attention should be masked in TransformerEncoderDecoder.')
    self.create_child('decoder', decoder_params)

    # Optional separate embedding layer for target ids.
    if p.decoder_embedding_tpl is not None:
      decoder_embedding_params = p.decoder_embedding_tpl.Copy()
      assert (decoder_embedding_params.embedding_dims == 0 or
              decoder_embedding_params.embedding_dims == p.model_dims)
      decoder_embedding_params.embedding_dims = p.model_dims
      self.create_child('decoder_embedding_lookup', decoder_embedding_params)

    # Optional NGrammer layer for the decoder.
    # Paper: https://openreview.net/forum?id=GxjCYmQAody
    if p.decoder_ngrammer_tpl is not None:
      self.create_child('decoder_ngrammer', p.decoder_ngrammer_tpl)

    # Decoder output layer norm.
    decoder_ln_params = normalizations.LayerNorm.Params().Set(
        input_dims=p.model_dims)
    self.create_child('decoder_ln', decoder_ln_params)

    # Final softmax.
    softmax_params = p.softmax_tpl.Copy()
    assert (softmax_params.input_dims == 0 or
            softmax_params.input_dims == p.model_dims)
    softmax_params.input_dims = p.model_dims
    self.create_child('softmax', softmax_params)
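
For context, a minimal configuration sketch (hypothetical, not taken from the
source): the fields set on TransformerEncoderDecoder.Params() are the ones read
by this constructor, while the StackedTransformer settings (num_layers,
num_heads, hidden_dims), the softmax num_classes field, and the
Params()/Copy()/Set()/Instantiate() calls follow the usual lingvo/jax
conventions and should be treated as assumptions. Note that the constructor
requires the encoder stack to be unmasked and the decoder stack to be masked.

  # Hypothetical usage sketch for TransformerEncoderDecoder.
  from lingvo.jax.layers import transformers

  stack_tpl = transformers.StackedTransformer.Params().Set(
      num_layers=6, num_heads=8, hidden_dims=2048)
  p = transformers.TransformerEncoderDecoder.Params().Set(
      name='enc_dec', model_dims=512, packed_input=False)
  # Encoder self-attention must be unmasked, decoder self-attention masked.
  p.encoder_stacked_transformer_tpl = stack_tpl.Copy().Set(
      mask_self_attention=False)
  p.decoder_stacked_transformer_tpl = stack_tpl.Copy().Set(
      mask_self_attention=True)
  p.softmax_tpl.num_classes = 32000  # Illustrative vocabulary size.
  # model_dims and packed_input are propagated into both stack templates and
  # the softmax by this constructor, so they are not repeated here.
  enc_dec = p.Instantiate()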