in tensorflow_addons/seq2seq/attention_wrapper.py [0:0]
def setup_memory(self, memory, memory_sequence_length=None, memory_mask=None):
    """Pre-process the memory before it is first queried.

    This should only be called once, at the first invocation of
    `call()`.

    Args:
        memory: The memory to query; usually the output of an RNN
            encoder. Expected shape is `[batch_size, max_time, ...]`.
        memory_sequence_length: (Optional) Sequence lengths for the
            batch entries in memory. When given, memory rows are masked
            with zeros past the respective sequence lengths.
        memory_mask: (Optional) Boolean tensor of shape
            `[batch_size, max_time]`; any position holding False is
            ignored in the memory.

    Raises:
        ValueError: If both `memory_sequence_length` and `memory_mask`
            are supplied.
    """
    # At most one of the two masking inputs may be provided.
    if not (memory_sequence_length is None or memory_mask is None):
        raise ValueError(
            "memory_sequence_length and memory_mask cannot be "
            "used at same time for attention."
        )
    with tf.name_scope(self.name or "BaseAttentionMechanismInit"):
        self.values = _prepare_memory(
            memory,
            memory_sequence_length=memory_sequence_length,
            memory_mask=memory_mask,
            check_inner_dims_defined=self._check_inner_dims_defined,
        )

        # The memory and mask may not have flowed through __call__(),
        # so they can lack proper keras metadata; flag each tensor as
        # already checked.
        # TODO(omalleyt12): Remove this hack once the mask has proper
        # keras history.
        def _flag_history_checked(t):
            t._keras_history_checked = True  # pylint: disable=protected-access

        tf.nest.map_structure(_flag_history_checked, self.values)

        # Project the memory into key space when a memory layer exists;
        # otherwise keys and values coincide.
        self.keys = (
            self.memory_layer(self.values)
            if self.memory_layer is not None
            else self.values
        )
        # Prefer static dimensions; fall back to the dynamic shape when
        # a static dim is unknown (None).
        self.batch_size = self.keys.shape[0] or tf.shape(self.keys)[0]
        self._alignments_size = self.keys.shape[1] or tf.shape(self.keys)[1]

        if memory_mask is not None or memory_sequence_length is not None:
            # Wrap the default probability fn so scores at masked
            # positions are pushed to the dtype minimum before
            # normalization.
            base_probability_fn = self.default_probability_fn

            def _masked_probability_fn(score, prev):
                masked_score = _maybe_mask_score(
                    score,
                    memory_mask=memory_mask,
                    memory_sequence_length=memory_sequence_length,
                    score_mask_value=score.dtype.min,
                )
                return base_probability_fn(masked_score, prev)

            self.probability_fn = _masked_probability_fn
    self._memory_initialized = True