in nmt/gnmt_model.py [0:0]
def _build_encoder(self, hparams):
    """Construct the encoder sub-graph, using the GNMT layout when requested.

    When `hparams.encoder_type` is "uni" or "bi" the work is delegated to the
    parent class implementation. When it is "gnmt", the encoder consists of
    one bidirectional layer (built with no residual connections) followed by
    `self.num_encoder_layers - 1` unidirectional layers.

    As a side effect, the embedded source input is stored on
    `self.encoder_emb_inp`.

    Args:
      hparams: Hyperparameter object. `encoder_type` is read here; the object
        itself is forwarded to the layer-building helpers.

    Returns:
      A tuple `(encoder_outputs, encoder_state)`. For the "gnmt" case,
      `encoder_state` is a tuple whose first element is index 1 of the
      bidirectional layer's state, followed by the state(s) of the
      unidirectional layers.

    Raises:
      ValueError: if `hparams.encoder_type` is not "uni", "bi" or "gnmt".
    """
    encoder_type = hparams.encoder_type

    # Plain uni/bi-directional encoders are handled by the parent class.
    if encoder_type in ("uni", "bi"):
        return super(GNMTModel, self)._build_encoder(hparams)

    if encoder_type != "gnmt":
        raise ValueError("Unknown encoder_type %s" % encoder_type)

    # GNMT layout: exactly one bidirectional layer; the rest unidirectional.
    bi_layer_count = 1
    uni_layer_count = self.num_encoder_layers - bi_layer_count
    utils.print_out("# Build a GNMT encoder")
    utils.print_out(" num_bi_layers = %d" % bi_layer_count)
    utils.print_out(" num_uni_layers = %d" % uni_layer_count)

    batch = self.iterator
    source_ids = batch.source
    if self.time_major:
        source_ids = tf.transpose(source_ids)

    with tf.variable_scope("encoder") as encoder_scope:
        scope_dtype = encoder_scope.dtype

        embedded_source = self.encoder_emb_lookup_fn(
            self.embedding_encoder, source_ids)
        self.encoder_emb_inp = embedded_source

        # Bottom layer: bidirectional RNN provided by the base Model class.
        bi_outputs, bi_state = self._build_bidirectional_rnn(
            inputs=embedded_source,
            sequence_length=batch.source_sequence_length,
            dtype=scope_dtype,
            hparams=hparams,
            num_bi_layers=bi_layer_count,
            num_bi_residual_layers=0,  # no residual connection
        )

        # Upper layers: unidirectional, built either layer by layer or all
        # at once depending on `self.extract_encoder_layers`.
        if self.extract_encoder_layers:
            build_uni_layers = self._build_individual_encoder_layers
        else:
            build_uni_layers = self._build_all_encoder_layers
        # Note the helpers return (state, outputs), the reverse of this
        # method's own return order.
        uni_state, encoder_outputs = build_uni_layers(
            bi_outputs, uni_layer_count, scope_dtype, hparams)

        # The decoder receives every encoder state except index 0 of the
        # bidirectional layer's state (presumably the forward direction --
        # TODO confirm against _build_bidirectional_rnn).
        if uni_layer_count == 1:
            # With a single uni layer the state is wrapped so that tuple
            # concatenation works.
            uni_states = (uni_state,)
        else:
            uni_states = uni_state
        encoder_state = (bi_state[1],) + uni_states

    return encoder_outputs, encoder_state