in ocr/utils/encoder_decoder.py
def forward(self, step_input, states, mask=None):  # pylint: disable=arguments-differ
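"""Decode one step (inference) or a whole sequence (training).

Parameters
----------
step_input : NDArray
    Shape (batch_size, C) for a single decoding step, or
    (batch_size, length, C) for a full sequence.
states : list of NDArray
    Decoder states. May start with the embeddings of the previously decoded
    steps (or None), followed by the encoder memory and, when the encoder
    valid length is known, the memory mask.
mask : NDArray or None
    Self-attention mask; when None, a causal mask is built internally.

Returns
-------
step_output : NDArray
new_states : list of NDArray
step_additional_outputs : list of NDArray
"""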
input_shape = step_input.shape
mem_mask = None
# During inference the input is a single 2-D step: expand it to NTC layout and,
# if the embeddings of earlier steps were kept in the states, prepend them.
# Otherwise (full-sequence input) drop the leading None placeholder from the states.
if len(input_shape) == 2:
if self._encoder_valid_length is not None:
has_last_embeds = len(states) == 3
else:
has_last_embeds = len(states) == 2
if has_last_embeds:
last_embeds = states[0]
step_input = mx.nd.concat(last_embeds,
mx.nd.expand_dims(step_input, axis=1),
dim=1)
states = states[1:]
else:
step_input = mx.nd.expand_dims(step_input, axis=1)
elif states[0] is None:
states = states[1:]
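# After the branches above, states holds the encoder memory and, when the
# encoder valid length is known, the memory mask. Temporarily broadcast the
# memory mask across the query (time) dimension to match the input length.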
has_mem_mask = (len(states) == 2)
if has_mem_mask:
_, mem_mask = states
augmented_mem_mask = mx.nd.expand_dims(mem_mask, axis=1)\
.broadcast_axes(axis=1, size=step_input.shape[1])
states[-1] = augmented_mem_mask
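# No explicit mask was given: build a causal (lower-triangular) self-attention
# mask so that position i only attends to positions j <= i, then broadcast it
# over the batch dimension.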
if mask is None:
length_array = mx.nd.arange(step_input.shape[1], ctx=step_input.context)
mask = mx.nd.broadcast_lesser_equal(
length_array.reshape((1, -1)),
length_array.reshape((-1, 1)))
mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0),
axis=0, size=step_input.shape[0])
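# Append the time-step indices to the states for the parent forward (e.g. to
# add positional information); they are removed from the states again below.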
steps = mx.nd.arange(step_input.shape[1], ctx=step_input.context)
states.append(steps)
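# Scale the embeddings by sqrt(C) (the last, channel dimension) before running
# the stacked decoder cells, as in the standard Transformer.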
step_output, step_additional_outputs = \
super(TransformerDecoder, self).forward(step_input * math.sqrt(step_input.shape[-1]), #pylint: disable=too-many-function-args
states, mask)
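# Undo the temporary changes: drop the appended step indices and restore the
# original (un-broadcast) memory mask before returning the states.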
states = states[:-1]
if has_mem_mask:
states[-1] = mem_mask
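# Keep the (possibly concatenated) input embeddings as the first state so the
# next decoding step can extend them.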
new_states = [step_input] + states
# During inference, return only the output of the most recent time step.
if len(input_shape) == 2:
step_output = step_output[:, -1, :]
return step_output, new_states, step_additional_outputs
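
A minimal standalone sketch (not part of the repository) of the causal-mask
construction used above, runnable with only MXNet; the sequence length and
batch size are illustrative assumptions.

import mxnet as mx

seq_len, batch_size = 4, 2
length_array = mx.nd.arange(seq_len)
# mask[i, j] = 1 when j <= i: each position attends to itself and earlier ones
mask = mx.nd.broadcast_lesser_equal(length_array.reshape((1, -1)),
                                    length_array.reshape((-1, 1)))
# Broadcast the (seq_len, seq_len) mask over the batch dimension
mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0),
                            axis=0, size=batch_size)
print(mask.shape)  # (2, 4, 4)
print(mask[0])     # lower-triangular matrix of ones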