def call()

in tensorflow_addons/seq2seq/attention_wrapper.py [0:0]


    def call(self, inputs, mask=None, setup_memory=False, **kwargs):
        """Setup the memory or query the attention.

        There are two case here, one for setup memory, and the second is query
        the attention score. `setup_memory` is the flag to indicate which mode
        it is. The input list will be treated differently based on that flag.

        Args:
          inputs: a list of tensors that could either be `query` and `state`, or
            `memory` and `memory_sequence_length`.
            `query` is the tensor of dtype matching `memory` and shape
            `[batch_size, query_depth]`.
            `state` is the tensor of dtype matching `memory` and shape
            `[batch_size, alignments_size]`. (`alignments_size` is memory's
            `max_time`).
            `memory` is the memory to query; usually the output of an RNN
            encoder. The tensor should be shaped `[batch_size, max_time, ...]`.
            `memory_sequence_length` (optional) is the sequence lengths for the
            batch entries in memory. If provided, the memory tensor rows are
            masked with zeros for values past the respective sequence lengths.
          mask: optional boolean tensor with shape `[batch_size, max_time]` for
            the memory mask. If it is not None, memory entries whose mask value
            is False are filtered out of the attention calculation.
          setup_memory: boolean, whether the inputs are for setting up the
            memory, or for querying the attention.
          **kwargs: Dict, other keyword arguments for the call method.
        Returns:
          Either the processed memory or the attention score, depending on
          `setup_memory`.
        """
        if setup_memory:
            if isinstance(inputs, list):
                if len(inputs) not in (1, 2):
                    raise ValueError(
                        "Expect inputs to have 1 or 2 tensors, got %d" % len(inputs)
                    )
                memory = inputs[0]
                memory_sequence_length = inputs[1] if len(inputs) == 2 else None
                memory_mask = mask
            else:
                memory, memory_sequence_length = inputs, None
                memory_mask = mask
            self.setup_memory(memory, memory_sequence_length, memory_mask)
            # Force self.built to False here since only the memory has been
            # initialized; the real query/state has not gone through call()
            # yet. The layer should be built and called again.
            self.built = False
            # Return the processed memory in order to create the Keras
            # connectivity data for it.
            return self.values
        else:
            if not self._memory_initialized:
                raise ValueError(
                    "Cannot query the attention before the setup of memory"
                )
            if len(inputs) not in (2, 3):
                raise ValueError(
                    "Expect the inputs to have query, state, and optional "
                    "processed memory, got %d items" % len(inputs)
                )
            # Ignore the rest of the inputs and only care about the query and
            # state
            query, state = inputs[0], inputs[1]
            return self._calculate_attention(query, state)
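
A minimal usage sketch of the two modes follows. The shapes, the tensor values,
and the choice of `tfa.seq2seq.LuongAttention` are illustrative assumptions;
any attention mechanism that follows this `call()` contract is used the same
way.

    import tensorflow as tf
    import tensorflow_addons as tfa

    batch_size, max_time, depth, units = 4, 7, 16, 32
    memory = tf.random.normal([batch_size, max_time, depth])
    memory_sequence_length = tf.constant([7, 5, 6, 7])

    # Mode 1: set up the memory. The processed memory is returned so that
    # Keras can track the connectivity.
    attention_mechanism = tfa.seq2seq.LuongAttention(units)
    attention_mechanism([memory, memory_sequence_length], setup_memory=True)

    # Mode 2: query the attention with a query tensor and the previous
    # alignments ("state"). Returns the new alignments and the next state.
    query = tf.random.normal([batch_size, units])
    state = attention_mechanism.initial_alignments(batch_size, dtype=tf.float32)
    alignments, next_state = attention_mechanism([query, state])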