in tensorflow_addons/seq2seq/attention_wrapper.py [0:0]
def call(self, inputs, mask=None, setup_memory=False, **kwargs):
    """Set up the memory or query the attention.

    There are two cases here: one is to set up the memory, and the other
    is to query the attention score. `setup_memory` is the flag that
    indicates which mode it is; the input list is treated differently
    based on that flag.

    Args:
      inputs: a list of tensors that is either `query` and `state`, or
        `memory` and `memory_sequence_length`.
        `query` is the tensor of dtype matching `memory` and shape
        `[batch_size, query_depth]`.
        `state` is the tensor of dtype matching `memory` and shape
        `[batch_size, alignments_size]` (`alignments_size` is the
        memory's `max_time`).
        `memory` is the memory to query, usually the output of an RNN
        encoder. The tensor should be shaped `[batch_size, max_time, ...]`.
        `memory_sequence_length` (optional) is the sequence lengths for
        the batch entries in the memory. If provided, the memory tensor
        rows are masked with zeros for values past the respective
        sequence lengths.
      mask: optional bool tensor with shape `[batch, max_time]` for the
        memory mask. If it is not None, the corresponding items of the
        memory are filtered out of the score calculation.
      setup_memory: boolean, whether the inputs are for setting up the
        memory or for querying the attention.
      **kwargs: Dict, other keyword arguments for the call method.

    Returns:
      Either the processed memory or the attention score, depending on
      `setup_memory`.
    """
    if setup_memory:
        if isinstance(inputs, list):
            if len(inputs) not in (1, 2):
                raise ValueError(
                    "Expected inputs to have 1 or 2 tensors, got %d" % len(inputs)
                )
            memory = inputs[0]
            memory_sequence_length = inputs[1] if len(inputs) == 2 else None
        else:
            memory, memory_sequence_length = inputs, None
        memory_mask = mask
        self.setup_memory(memory, memory_sequence_length, memory_mask)
        # Force self.built back to False here: only the memory has been
        # initialized, and the real query/state have not been passed to
        # call() yet. The layer should be built and called again.
        self.built = False
        # Return the processed memory in order to create the Keras
        # connectivity data for it.
        return self.values
    else:
        if not self._memory_initialized:
            raise ValueError(
                "Cannot query the attention before the setup of memory"
            )
        if len(inputs) not in (2, 3):
            raise ValueError(
                "Expected the inputs to have query, state, and an optional "
                "processed memory, got %d items" % len(inputs)
            )
        # Ignore the rest of the inputs and only care about the query
        # and state.
        query, state = inputs[0], inputs[1]
        return self._calculate_attention(query, state)
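
# A minimal usage sketch of the two-phase call protocol above. This is an
# illustration, not part of the module: `tfa.seq2seq.LuongAttention` is used
# as one concrete `AttentionMechanism` subclass, and the tensor shapes are
# hypothetical.
#
#     import tensorflow as tf
#     import tensorflow_addons as tfa
#
#     mechanism = tfa.seq2seq.LuongAttention(units=128)
#     encoder_outputs = tf.random.normal([4, 10, 128])  # [batch, max_time, depth]
#     sequence_length = tf.fill([4], 10)
#
#     # Phase 1: set up the memory; returns the processed memory values.
#     mechanism([encoder_outputs, sequence_length], setup_memory=True)
#
#     # Phase 2: query the attention with a decoder query and the previous
#     # alignments state; returns `(alignments, next_state)`.
#     query = tf.random.normal([4, 128])
#     state = mechanism.initial_alignments(4, dtype=tf.float32)
#     alignments, next_state = mechanism([query, state])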