in tensorflow_addons/seq2seq/attention_wrapper.py [0:0]
def call(self, inputs, state, **kwargs):
"""Perform a step of attention-wrapped RNN.
- Step 1: Mix the `inputs` and previous step's `attention` output via
`cell_input_fn`.
- Step 2: Call the wrapped `cell` with this input and its previous
state.
- Step 3: Score the cell's output with `attention_mechanism`.
- Step 4: Calculate the alignments by passing the score through the
attention mechanism's `probability_fn` (softmax by default).
- Step 5: Calculate the context vector as the inner product between the
alignments and the attention_mechanism's values (memory).
- Step 6: Calculate the attention output by concatenating the cell
output and context and passing the result through the attention layer
(a linear layer with `attention_layer_size` outputs).
Args:
inputs: (Possibly nested tuple of) Tensor, the input at this time
step.
state: An instance of `tfa.seq2seq.AttentionWrapperState` containing
tensors from the previous time step.
**kwargs: Dict, other keyword arguments for the cell call method.
Returns:
A tuple `(attention_or_cell_output, next_state)`, where:
- `attention_or_cell_output` is the attention output if
`output_attention` is `True`, otherwise the cell output.
- `next_state` is an instance of `tfa.seq2seq.AttentionWrapperState`
containing the state calculated at this time step.
Raises:
TypeError: If `state` is not an instance of
`tfa.seq2seq.AttentionWrapperState` and cannot be converted to one.
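Example:
A minimal sketch of how this step is typically driven (the cell,
attention mechanism, and shapes below are illustrative choices, not
requirements):
>>> import tensorflow as tf
>>> import tensorflow_addons as tfa
>>> memory = tf.random.normal([4, 7, 16])  # [batch, max_time, depth]
>>> mechanism = tfa.seq2seq.LuongAttention(16, memory)
>>> wrapper = tfa.seq2seq.AttentionWrapper(
...     tf.keras.layers.LSTMCell(16), mechanism, attention_layer_size=16)
>>> state = wrapper.get_initial_state(batch_size=4, dtype=tf.float32)
>>> attention_output, next_state = wrapper(tf.random.normal([4, 16]), state)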
"""
if not isinstance(state, AttentionWrapperState):
try:
state = AttentionWrapperState(*state)
except TypeError:
raise TypeError(
"Expected state to be instance of AttentionWrapperState or "
"values that can construct AttentionWrapperState. "
"Received type %s instead." % type(state)
)
# Step 1: Calculate the true inputs to the cell based on the
# previous attention value.
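# By default, `cell_input_fn` concatenates `inputs` and the previous
# attention along the last axis.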
cell_inputs = self._cell_input_fn(inputs, state.attention)
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state, **kwargs)
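# Repack the new cell state into the structure of the previous cell state,
# in case the wrapped cell returned a compatible but differently nested
# structure.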
next_cell_state = tf.nest.pack_sequence_as(
cell_state, tf.nest.flatten(next_cell_state)
)
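# Prefer the statically known batch size; fall back to the dynamic shape
# when it is unknown.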
cell_batch_size = cell_output.shape[0] or tf.shape(cell_output)[0]
error_message = (
"When applying AttentionWrapper %s: " % self.name
+ "Non-matching batch sizes between the memory "
"(encoder output) and the query (decoder output). Are you using "
"the BeamSearchDecoder? You may need to tile your memory input "
"via the tfa.seq2seq.tile_batch function with argument "
"multiple=beam_width."
)
with tf.control_dependencies(
self._batch_size_checks(cell_batch_size, error_message)
): # pylint: disable=bad-continuation
cell_output = tf.identity(cell_output, name="checked_cell_output")
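# Normalize attention state and alignment history to lists so single- and
# multi-mechanism wrappers share the same loop below.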
if self._is_multi:
previous_attention_state = state.attention_state
previous_alignment_history = state.alignment_history
else:
previous_attention_state = [state.attention_state]
previous_alignment_history = [state.alignment_history]
all_alignments = []
all_attentions = []
all_attention_states = []
maybe_all_histories = []
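# Steps 3-6: for each attention mechanism, score the cell output, turn the
# scores into alignments, form the context vector, and compute the
# attention output.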
for i, attention_mechanism in enumerate(self._attention_mechanisms):
attention, alignments, next_attention_state = self._attention_fn(
attention_mechanism,
cell_output,
previous_attention_state[i],
self._attention_layers[i] if self._attention_layers else None,
)
alignment_history = (
previous_alignment_history[i].write(
previous_alignment_history[i].size(), alignments
)
if self._alignment_history
else ()
)
all_attention_states.append(next_attention_state)
all_alignments.append(alignments)
all_attentions.append(attention)
maybe_all_histories.append(alignment_history)
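# Concatenate attention from all mechanisms; this is both the (optional)
# output and the attention carried in the next state.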
attention = tf.concat(all_attentions, 1)
next_state = AttentionWrapperState(
cell_state=next_cell_state,
attention=attention,
attention_state=self._item_or_tuple(all_attention_states),
alignments=self._item_or_tuple(all_alignments),
alignment_history=self._item_or_tuple(maybe_all_histories),
)
if self._output_attention:
return attention, next_state
else:
return cell_output, next_state