in nmt/model.py [0:0]
def _build_encoder_from_sequence(self, hparams, sequence, sequence_length):
    """Build an encoder from a sequence.

    Embeds `sequence` with `self.encoder_emb_lookup_fn` and runs it through
    either a unidirectional RNN (`tf.nn.dynamic_rnn`) or a stack of
    bidirectional layers (`self._build_bidirectional_rnn`), all inside the
    "encoder" variable scope.

    Args:
      hparams: hyperparameters; `hparams.encoder_type` selects "uni" or "bi".
      sequence: tensor with input sequence data, passed to
        `self.encoder_emb_lookup_fn` together with `self.embedding_encoder`.
        Presumably [batch_size, max_time] token ids -- TODO confirm.
      sequence_length: tensor with length of the input sequence.

    Returns:
      encoder_outputs: RNN encoder outputs. For "uni" these are time-major
        when `self.time_major` is set and batch-major otherwise (per
        `tf.nn.dynamic_rnn`'s `time_major` flag). For "bi" they are whatever
        `self._build_bidirectional_rnn` returns (layout not visible here).
      encoder_state: RNN encoder state. For "bi" with more than one
        bidirectional layer, a tuple interleaving the per-layer forward and
        backward states: (fw_0, bw_0, fw_1, bw_1, ...).

    Raises:
      ValueError: if encoder_type is neither "uni" nor "bi".
    """
    num_layers = self.num_encoder_layers
    num_residual_layers = self.num_encoder_residual_layers

    if self.time_major:
        # With no `perm`, tf.transpose reverses all dimensions; for a rank-2
        # input this swaps its two axes (presumably batch <-> time).
        sequence = tf.transpose(sequence)

    with tf.variable_scope("encoder") as scope:
        dtype = scope.dtype

        # Stored on the instance; presumably read elsewhere in the model.
        self.encoder_emb_inp = self.encoder_emb_lookup_fn(
            self.embedding_encoder, sequence)

        if hparams.encoder_type == "uni":
            utils.print_out(" num_layers = %d, num_residual_layers=%d" %
                            (num_layers, num_residual_layers))
            cell = self._build_encoder_cell(hparams, num_layers,
                                            num_residual_layers)

            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                cell,
                self.encoder_emb_inp,
                dtype=dtype,
                sequence_length=sequence_length,
                time_major=self.time_major,
                swap_memory=True)
        elif hparams.encoder_type == "bi":
            # The configured layer budget is split in half: each
            # bidirectional layer has a forward and a backward direction.
            # Floor division keeps the counts integral without a float
            # round-trip.
            num_bi_layers = num_layers // 2
            num_bi_residual_layers = num_residual_layers // 2
            utils.print_out(" num_bi_layers = %d, num_bi_residual_layers=%d" %
                            (num_bi_layers, num_bi_residual_layers))

            encoder_outputs, bi_encoder_state = (
                self._build_bidirectional_rnn(
                    inputs=self.encoder_emb_inp,
                    sequence_length=sequence_length,
                    dtype=dtype,
                    hparams=hparams,
                    num_bi_layers=num_bi_layers,
                    num_bi_residual_layers=num_bi_residual_layers))

            if num_bi_layers == 1:
                encoder_state = bi_encoder_state
            else:
                # Alternately take the forward and backward state of each
                # layer, producing (fw_0, bw_0, fw_1, bw_1, ...).
                fw_states, bw_states = bi_encoder_state[0], bi_encoder_state[1]
                encoder_state = tuple(
                    layer_state
                    for layer_id in range(num_bi_layers)
                    for layer_state in (fw_states[layer_id],
                                        bw_states[layer_id]))
        else:
            raise ValueError("Unknown encoder_type %s" % hparams.encoder_type)

    # Use the top layer for now: despite its name, this list records only
    # the encoder outputs, not per-layer states.
    self.encoder_state_list = [encoder_outputs]

    return encoder_outputs, encoder_state