def setup_memory()

in tensorflow_addons/seq2seq/attention_wrapper.py [0:0]


    def setup_memory(self, memory, memory_sequence_length=None, memory_mask=None):
        """Pre-process the memory before actually query the memory.

        This should only be called once at the first invocation of `call()`.

        Args:
          memory: The memory to query; usually the output of an RNN encoder.
            This tensor should be shaped `[batch_size, max_time, ...]`.
          memory_sequence_length (optional): Sequence lengths for the batch
            entries in memory. If provided, the memory tensor rows are masked
            with zeros for values past the respective sequence lengths.
          memory_mask (optional): A boolean tensor with shape `[batch_size,
            max_time]`. Positions where the mask is False are ignored when
            querying the memory.
        """
        if memory_sequence_length is not None and memory_mask is not None:
            raise ValueError(
                "memory_sequence_length and memory_mask cannot be "
                "used at same time for attention."
            )
        with tf.name_scope(self.name or "BaseAttentionMechanismInit"):
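            # _prepare_memory masks memory rows past the given sequence lengths
            # (or where memory_mask is False) with zeros and optionally checks
            # that the inner dimensions are statically defined.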
            self.values = _prepare_memory(
                memory,
                memory_sequence_length=memory_sequence_length,
                memory_mask=memory_mask,
                check_inner_dims_defined=self._check_inner_dims_defined,
            )
            # Mark the values as checked since the memory and memory mask might
            # not have been passed through __call__() and therefore lack proper
            # keras metadata.
            # TODO(omalleyt12): Remove this hack once the mask has proper
            # keras history.

            def _mark_checked(tensor):
                tensor._keras_history_checked = True  # pylint: disable=protected-access

            tf.nest.map_structure(_mark_checked, self.values)
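            # Project the memory through the memory_layer (typically a Dense
            # layer) to obtain the attention keys; without a memory_layer the
            # raw values double as keys.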
            if self.memory_layer is not None:
                self.keys = self.memory_layer(self.values)
            else:
                self.keys = self.values
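            # Prefer the statically known batch and time dimensions; fall back
            # to dynamic shapes when they are unknown at graph construction time.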
            self.batch_size = self.keys.shape[0] or tf.shape(self.keys)[0]
            self._alignments_size = self.keys.shape[1] or tf.shape(self.keys)[1]
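            # Wrap the default probability function so that scores at masked
            # positions are pushed down to the dtype's minimum value before
            # normalization (e.g. softmax), effectively zeroing their weights.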
            if memory_mask is not None or memory_sequence_length is not None:
                unwrapped_probability_fn = self.default_probability_fn

                def _mask_probability_fn(score, prev):
                    return unwrapped_probability_fn(
                        _maybe_mask_score(
                            score,
                            memory_mask=memory_mask,
                            memory_sequence_length=memory_sequence_length,
                            score_mask_value=score.dtype.min,
                        ),
                        prev,
                    )

                self.probability_fn = _mask_probability_fn
        self._memory_initialized = True
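
A minimal usage sketch (not part of the source file), assuming a `tfa.seq2seq.BahdanauAttention` built with deferred memory: construct the mechanism with `memory=None`, then call `setup_memory()` once before the first attention query.

    import tensorflow as tf
    import tensorflow_addons as tfa

    # Build the attention mechanism without memory, then attach it later.
    attention = tfa.seq2seq.BahdanauAttention(units=128, memory=None)

    encoder_outputs = tf.random.normal([4, 10, 64])  # [batch_size, max_time, depth]
    encoder_lengths = tf.constant([10, 8, 6, 3])     # valid steps per batch entry

    # Should only be called once, before the first invocation of call().
    attention.setup_memory(encoder_outputs, memory_sequence_length=encoder_lengths)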