in nmt/attention_model.py [0:0]
def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                        source_sequence_length):
    """Construct the decoder RNN cell, wrapped with attention when enabled.

    When `self.has_attention` is falsy the call is forwarded unchanged to the
    parent class implementation. Otherwise a multi-layer RNN cell is created,
    wrapped in an `AttentionWrapper` over the encoder outputs, and finally
    wrapped in a `DeviceWrapper`.

    Args:
      hparams: Hyperparameter object. Fields read here:
        `attention_architecture`, `num_units`, `infer_mode`, `beam_width`,
        `attention`, `unit_type`, `forget_bias`, `dropout`,
        `output_attention` and `pass_hidden_state`.
      encoder_outputs: Encoder outputs used as the attention memory. They
        are transposed with permutation `[1, 0, 2]` when `self.time_major`
        is set.
      encoder_state: Encoder final state; used as the decoder's initial
        cell state when `hparams.pass_hidden_state` is truthy.
      source_sequence_length: Source lengths forwarded to the attention
        mechanism.

    Returns:
      A `(cell, decoder_initial_state)` tuple on the attention path, or
      whatever the parent class returns on the no-attention path.

    Raises:
      ValueError: If attention is enabled and
        `hparams.attention_architecture` is not `"standard"`.
    """
    # Guard: models without attention reuse the parent implementation.
    if not self.has_attention:
        return super(AttentionModel, self)._build_decoder_cell(
            hparams, encoder_outputs, encoder_state, source_sequence_length)

    # Guard: only the "standard" architecture is handled by this method.
    if hparams.attention_architecture != "standard":
        raise ValueError(
            "Unknown attention architecture %s" % hparams.attention_architecture)

    hidden_units = hparams.num_units
    layer_count = self.num_decoder_layers
    residual_layer_count = self.num_decoder_residual_layers
    decoding_mode = hparams.infer_mode
    state_dtype = tf.float32

    # The attention memory must be batch-major.
    attention_memory = (
        tf.transpose(encoder_outputs, [1, 0, 2])
        if self.time_major else encoder_outputs)

    is_inference = self.mode == tf.contrib.learn.ModeKeys.INFER
    if is_inference and decoding_mode == "beam_search":
        # The helper hands back replacements for the memory, the lengths and
        # the encoder state, along with the batch size to use below.
        (attention_memory, source_sequence_length, encoder_state,
         batch_size) = self._prepare_beam_search_decoder_inputs(
             hparams.beam_width, attention_memory, source_sequence_length,
             encoder_state)
    else:
        batch_size = self.batch_size

    # Attention mechanism over the (batch-major) encoder memory.
    mechanism = self.attention_mechanism_fn(
        hparams.attention, hidden_units, attention_memory,
        source_sequence_length, self.mode)

    rnn_stack = model_helper.create_rnn_cell(
        unit_type=hparams.unit_type,
        num_units=hidden_units,
        num_layers=layer_count,
        num_residual_layers=residual_layer_count,
        forget_bias=hparams.forget_bias,
        dropout=hparams.dropout,
        num_gpus=self.num_gpus,
        mode=self.mode,
        single_cell_fn=self.single_cell_fn)

    # Alignments are recorded only for INFER runs that do not use beam search.
    record_alignments = is_inference and decoding_mode != "beam_search"
    attention_cell = tf.contrib.seq2seq.AttentionWrapper(
        rnn_stack,
        mechanism,
        attention_layer_size=hidden_units,
        alignment_history=record_alignments,
        output_attention=hparams.output_attention,
        name="attention")

    # TODO(thangluong): do we need num_layers, num_gpus?
    # The device string is derived from the index of the last decoder layer.
    device_name = model_helper.get_device_str(layer_count - 1, self.num_gpus)
    decoder_cell = tf.contrib.rnn.DeviceWrapper(attention_cell, device_name)

    # Start from an all-zero state; optionally seed its cell state with the
    # encoder's final state.
    decoder_initial_state = decoder_cell.zero_state(batch_size, state_dtype)
    if hparams.pass_hidden_state:
        decoder_initial_state = decoder_initial_state.clone(
            cell_state=encoder_state)

    return decoder_cell, decoder_initial_state