in lingvo/jax/layers/transformers.py [0:0]
def __init__(self, params):
# This will create the encoder and decoder stacks under the keys 'encoder' and 'decoder'.
super().__init__(params)
p = self.params
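# Helper that propagates model_dims and packed_input into a stacked-transformer
# template, handling both the plain and the repeated stack variants.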
def set_model_dims_and_packing(stacked_transformer_tpl, model_dims,
packed_input):
if stacked_transformer_tpl.cls == StackedTransformer:
assert (stacked_transformer_tpl.model_dims == 0 or
stacked_transformer_tpl.model_dims == model_dims)
stacked_transformer_tpl.model_dims = model_dims
stacked_transformer_tpl.packed_input = packed_input
elif stacked_transformer_tpl.cls == StackedTransformerRepeated:
assert (stacked_transformer_tpl.block.model_dims == 0 or
stacked_transformer_tpl.block.model_dims == model_dims)
stacked_transformer_tpl.block.model_dims = model_dims
stacked_transformer_tpl.block.packed_input = packed_input
else:
assert False, f'{stacked_transformer_tpl.cls} not supported.'
# Create position embeddings.
if p.position_emb_tpl is not None:
assert (p.position_emb_tpl.embedding_dims == 0 or
p.position_emb_tpl.embedding_dims == p.model_dims)
p.position_emb_tpl.embedding_dims = p.model_dims
self.create_child('position_emb', p.position_emb_tpl)
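# Note: a single 'position_emb' child is created, shared by the encoder and
# decoder input embeddings.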
# Create the encoder.
if p.encoder_stacked_transformer_tpl is None:
raise ValueError(
'Encoder stack must be specified for TransformerEncoderDecoder.')
# Use the user-specified StackedTransformer for the encoder, assuming
# everything is set up appropriately.
encoder_params = p.encoder_stacked_transformer_tpl.Copy()
set_model_dims_and_packing(encoder_params, p.model_dims, p.packed_input)
# Assert that encoder is not masked.
if encoder_params.cls == StackedTransformer:
mask_self_attention = encoder_params.mask_self_attention
elif encoder_params.cls == StackedTransformerRepeated:
mask_self_attention = encoder_params.block.mask_self_attention
else:
raise ValueError('Unknown encoder stack.')
if mask_self_attention:
raise ValueError(
'Encoder attention should be un-masked in TransformerEncoderDecoder.')
self.create_child('encoder', encoder_params)
# Optional separate embedding layer for source ids.
if p.encoder_embedding_tpl is not None:
encoder_embedding_params = p.encoder_embedding_tpl.Copy()
assert (encoder_embedding_params.embedding_dims == 0 or
encoder_embedding_params.embedding_dims == p.model_dims)
encoder_embedding_params.embedding_dims = p.model_dims
self.create_child('encoder_embedding_lookup', encoder_embedding_params)
# Optional NGrammer layer for the encoder.
# Paper: https://openreview.net/forum?id=GxjCYmQAody
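# NGrammer augments the input token embeddings with latent n-gram
# representations before they are fed to the encoder stack.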
if p.encoder_ngrammer_tpl is not None:
self.create_child('encoder_ngrammer', p.encoder_ngrammer_tpl)
# Encoder output layer norm.
encoder_ln_params = normalizations.LayerNorm.Params().Set(
input_dims=p.model_dims)
self.create_child('encoder_ln', encoder_ln_params)
# Create the decoder.
if p.decoder_stacked_transformer_tpl is None:
raise ValueError(
'Decoder stack must be specified for TransformerEncoderDecoder.')
# Use the user-specified StackedTransformer for the decoder, assuming
# everything is set up appropriately.
decoder_params = p.decoder_stacked_transformer_tpl.Copy()
set_model_dims_and_packing(decoder_params, p.model_dims, p.packed_input)
# Assert that decoder is masked.
if decoder_params.cls == StackedTransformer:
mask_self_attention = decoder_params.mask_self_attention
elif decoder_params.cls == StackedTransformerRepeated:
mask_self_attention = decoder_params.block.mask_self_attention
else:
raise ValueError('Unknown decoder stack.')
if not mask_self_attention:
raise ValueError(
'Decoder attention should be masked in TransformerEncoderDecoder.')
self.create_child('decoder', decoder_params)
# Optional separate embedding layer for target ids.
if p.decoder_embedding_tpl is not None:
decoder_embedding_params = p.decoder_embedding_tpl.Copy()
assert (decoder_embedding_params.embedding_dims == 0 or
decoder_embedding_params.embedding_dims == p.model_dims)
decoder_embedding_params.embedding_dims = p.model_dims
self.create_child('decoder_embedding_lookup', decoder_embedding_params)
# Optional NGrammer layer for the decoder.
# Paper: https://openreview.net/forum?id=GxjCYmQAody
if p.decoder_ngrammer_tpl is not None:
self.create_child('decoder_ngrammer', p.decoder_ngrammer_tpl)
# Decoder output layer norm.
decoder_ln_params = normalizations.LayerNorm.Params().Set(
input_dims=p.model_dims)
self.create_child('decoder_ln', decoder_ln_params)
# Final softmax.
softmax_params = p.softmax_tpl.Copy()
assert (softmax_params.input_dims == 0 or
softmax_params.input_dims == p.model_dims)
softmax_params.input_dims = p.model_dims
self.create_child('softmax', softmax_params)
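# --------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). It builds a
# minimal TransformerEncoderDecoder config under the constraints enforced by
# __init__ above: the encoder stack un-masked, the decoder stack masked, and
# model_dims left at 0 on the stacks so that set_model_dims_and_packing fills
# it in. Field names not referenced in __init__ (num_layers, num_heads,
# hidden_dims) are assumptions about StackedTransformer.Params(); softmax_tpl,
# position_emb_tpl and any decoder cross-attention flag are assumed to come
# from the class defaults or to be set separately.
def _example_encoder_decoder_params(model_dims=512):
  enc_tpl = StackedTransformer.Params().Set(
      num_layers=6, num_heads=8, hidden_dims=4 * model_dims,
      mask_self_attention=False)  # Encoder self-attention must be un-masked.
  dec_tpl = StackedTransformer.Params().Set(
      num_layers=6, num_heads=8, hidden_dims=4 * model_dims,
      mask_self_attention=True)  # Decoder self-attention must be masked.
  return TransformerEncoderDecoder.Params().Set(
      model_dims=model_dims,
      packed_input=False,
      encoder_stacked_transformer_tpl=enc_tpl,
      decoder_stacked_transformer_tpl=dec_tpl)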