in ludwig/decoders/sequence_decoders.py [0:0]
def prepare_encoder_output_state(self, inputs):
    """Adapt the encoder's output state to the decoder cell's expected shape.

    The decoder cell expects its initial state to be a single ``[b, h]``
    tensor for rnn / gru cells, or a list ``[h, c]`` of two ``[b, h]``
    tensors for an lstm cell. The encoder may provide an explicit
    ``'encoder_output_state'`` (a tensor, or a list of 2 or 4 tensors
    depending on cell type and directionality), or only its ``'hidden'``
    output, which is reduced to ``[b, h]`` here. Each resulting state
    tensor is projected to ``self.state_size`` if its width differs.

    :param inputs: dict with either an ``'encoder_output_state'`` entry
        or a ``'hidden'`` entry (``[b, h]`` or ``[b, s, h]``).
    :return: a ``[b, h]`` state tensor, or a list of two such tensors
        when the decoder cell is an lstm.
    :raises ValueError: if ``'hidden'`` is neither 2d nor 3d.
    """
    if 'encoder_output_state' in inputs:
        encoder_output_state = inputs['encoder_output_state']
    else:
        hidden = inputs['hidden']
        if len(hidden.shape) == 3:  # encoder_output is a sequence
            # reduce_sequence returns a [b, h]
            encoder_output_state = self.reduce_sequence(hidden)
        elif len(hidden.shape) == 2:
            # already a [b, h]
            encoder_output_state = hidden
        else:
            raise ValueError("Only works for 1d or 2d encoder_output")

    # The state needs to be a list [h, c] in case of an lstm decoder
    # cell and a single tensor otherwise.
    if self.cell_type == 'lstm':
        if isinstance(encoder_output_state, list):
            if len(encoder_output_state) == 2:
                # This may be a unidirectional lstm or a bidirectional
                # gru / rnn; there is no way to tell. If it is a
                # unidirectional lstm, passing through works fine; if it
                # is a bidirectional gru / rnn, the output of one of the
                # directions will be treated as the initial c of the
                # lstm, which is weird and may lead to poor performance.
                # todo try to find a way to distinguish these two cases
                pass
            elif len(encoder_output_state) == 4:
                # the encoder was a bidirectional lstm:
                # average the 2 h and the 2 c vectors
                encoder_output_state = [
                    average(
                        [encoder_output_state[0], encoder_output_state[2]]
                    ),
                    average(
                        [encoder_output_state[1], encoder_output_state[3]]
                    )
                ]
            else:
                # Lists of length other than 2 or 4 are unexpected;
                # rather than raising, average everything and use the
                # result as both h and c.
                average_state = average(encoder_output_state)
                encoder_output_state = [average_state, average_state]
        else:
            # single tensor: reuse it as both h and c
            encoder_output_state = [encoder_output_state,
                                    encoder_output_state]
    elif isinstance(encoder_output_state, list):
        # Non-lstm decoder cell but the encoder produced a list of
        # states: keep only the hidden state(s), dropping any c.
        if len(encoder_output_state) == 2:
            # using h and ignoring c
            encoder_output_state = encoder_output_state[0]
        elif len(encoder_output_state) == 4:
            # using the average of the hs and ignoring the cs.
            # BUGFIX: the original wrote `encoder_output_state +
            # average(...)`, which computed the average and discarded
            # it, leaving the full 4-element list as the decoder state.
            encoder_output_state = average(
                [encoder_output_state[0], encoder_output_state[2]]
            )
        else:
            # Lists of length other than 2 or 4 are unexpected;
            # rather than raising, just average everything.
            encoder_output_state = average(encoder_output_state)
    # else: single tensor and a non-lstm cell — nothing to adapt.

    # At this point encoder_output_state is either a [b, h] tensor or a
    # list([b, h], [b, h]) if the decoder cell is an lstm, but h may not
    # be the same as the decoder state size, so project when needed.
    if isinstance(encoder_output_state, list):
        for i in range(len(encoder_output_state)):
            if (encoder_output_state[i].shape[1] !=
                    self.state_size):
                encoder_output_state[i] = self.project(
                    encoder_output_state[i]
                )
    else:
        if encoder_output_state.shape[1] != self.state_size:
            encoder_output_state = self.project(
                encoder_output_state
            )
    return encoder_output_state