in nmt/gnmt_model.py [0:0]
def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                        source_sequence_length):
  """Build an RNN cell with GNMT attention architecture."""
  # Standard attention
  if not self.is_gnmt_attention:
    return super(GNMTModel, self)._build_decoder_cell(
        hparams, encoder_outputs, encoder_state, source_sequence_length)

  # GNMT attention
  attention_option = hparams.attention
  attention_architecture = hparams.attention_architecture
  num_units = hparams.num_units
  infer_mode = hparams.infer_mode

  dtype = tf.float32

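  # AttentionWrapper expects batch-major memory of shape
  # [batch_size, max_time, num_units], so time-major encoder outputs are
  # transposed before being used as the attention memory.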
  if self.time_major:
    memory = tf.transpose(encoder_outputs, [1, 0, 2])
  else:
    memory = encoder_outputs

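  # For beam-search inference, tile the memory, source lengths and encoder
  # state beam_width times so each beam hypothesis attends over its own copy
  # of the encoder outputs; the effective batch size grows accordingly.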
  if (self.mode == tf.contrib.learn.ModeKeys.INFER and
      infer_mode == "beam_search"):
    memory, source_sequence_length, encoder_state, batch_size = (
        self._prepare_beam_search_decoder_inputs(
            hparams.beam_width, memory, source_sequence_length,
            encoder_state))
  else:
    batch_size = self.batch_size

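  # Scoring mechanism over the encoder memory; hparams.attention selects it
  # (by default luong / scaled_luong / bahdanau / normed_bahdanau).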
  attention_mechanism = self.attention_mechanism_fn(
      attention_option, num_units, memory, source_sequence_length, self.mode)

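  # One RNN cell per decoder layer, with dropout and device placement;
  # gnmt_residual_fn lets residual layers handle inputs that carry the extra
  # attention dimensions concatenated on by GNMTAttentionMultiCell.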
  cell_list = model_helper._cell_list(  # pylint: disable=protected-access
      unit_type=hparams.unit_type,
      num_units=num_units,
      num_layers=self.num_decoder_layers,
      num_residual_layers=self.num_decoder_residual_layers,
      forget_bias=hparams.forget_bias,
      dropout=hparams.dropout,
      num_gpus=self.num_gpus,
      mode=self.mode,
      single_cell_fn=self.single_cell_fn,
      residual_fn=gnmt_residual_fn
  )

  # Only wrap the bottom layer with the attention mechanism.
  attention_cell = cell_list.pop(0)

  # Only generate alignment in greedy INFER mode.
  alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and
                       infer_mode != "beam_search")
  attention_cell = tf.contrib.seq2seq.AttentionWrapper(
      attention_cell,
      attention_mechanism,
      attention_layer_size=None,  # don't use attention layer.
      output_attention=False,
      alignment_history=alignment_history,
      name="attention")

  if attention_architecture == "gnmt":
    cell = GNMTAttentionMultiCell(
        attention_cell, cell_list)
  elif attention_architecture == "gnmt_v2":
    cell = GNMTAttentionMultiCell(
        attention_cell, cell_list, use_new_attention=True)
  else:
    raise ValueError(
        "Unknown attention_architecture %s" % attention_architecture)

  if hparams.pass_hidden_state:
    decoder_initial_state = tuple(
        zs.clone(cell_state=es)
        if isinstance(zs, tf.contrib.seq2seq.AttentionWrapperState) else es
        for zs, es in zip(
            cell.zero_state(batch_size, dtype), encoder_state))
  else:
    decoder_initial_state = cell.zero_state(batch_size, dtype)

  return cell, decoder_initial_state