in ludwig/decoders/sequence_decoders.py [0:0]
def build_decoder_initial_state(self, batch_size, encoder_state, dtype):
    """Build the decoder RNN's initial state from the encoder's final state.

    The encoder and decoder may use different cell types, and the encoder
    state may or may not arrive wrapped in a list. The state is therefore
    first adapted to the decoder's ``cell_type``:

    * ``'lstm'``: a bare state, or a one-element list, becomes
      ``[state, state]``. Any other list is passed through unchanged.
    * any other cell type: if a list is given, only its first element is
      kept. A bare state is kept as-is.

    Args:
        self: decoder instance. Reads ``cell_type``, ``attention_mechanism``
            and ``decoder_rnncell``.
        batch_size: forwarded to ``self.decoder_rnncell.get_initial_state``
            (only called when an attention mechanism is set).
        encoder_state: final encoder state, either a single state or a list
            of states.
        dtype: forwarded to ``self.decoder_rnncell.get_initial_state``.

    Returns:
        With an attention mechanism: the decoder cell's initial state,
        cloned with ``cell_state`` replaced by the adapted encoder state.
        Without one: the adapted encoder state as a list.
    """
    # handle situation where encoder and decoder are different cell_types
    # and to account for inconsistent wrapping for encoder state w/in lists
    if self.cell_type == 'lstm':
        if not isinstance(encoder_state, list):
            encoder_state = [encoder_state, encoder_state]
        elif len(encoder_state) == 1:
            # a single wrapped state cannot fill both LSTM state slots;
            # reuse it for both, mirroring the bare-state case above.
            # A new list is built so the caller's list is not mutated.
            encoder_state = [encoder_state[0], encoder_state[0]]
    elif isinstance(encoder_state, list):
        encoder_state = encoder_state[0]

    if self.attention_mechanism is not None:
        # The cell's own initial state is only needed here: clone() keeps
        # its other fields (presumably attention-related; confirm against
        # the cell wrapper) and swaps in cell_state. Building it on the
        # non-attention path would only be thrown away.
        decoder_initial_state = self.decoder_rnncell.get_initial_state(
            batch_size=batch_size,
            dtype=dtype)
        return decoder_initial_state.clone(cell_state=encoder_state)

    if not isinstance(encoder_state, list):
        return [encoder_state]
    return encoder_state