def prepare_encoder_output_state()

in ludwig/decoders/sequence_decoders.py [0:0]


    def prepare_encoder_output_state(self, inputs):
        """Build the decoder's initial state from the encoder outputs.

        The decoder cell expects its initial state to be a list of two
        [b, h] tensors (h and c) when ``self.cell_type == 'lstm'``, and a
        single [b, h] tensor otherwise.  The encoder, however, may hand us
        anything from a raw sequence output to a 4-element bidirectional
        lstm state, so this method normalizes whatever it receives and,
        if necessary, projects it to ``self.state_size``.

        :param inputs: dict with either an ``'encoder_output_state'`` key
            (a tensor or a list of tensors) or a ``'hidden'`` key holding
            the encoder output, shaped [b, s, h] (sequence) or [b, h].
        :return: a [b, state_size] tensor, or a list of two such tensors
            when the decoder cell is an lstm.
        :raises ValueError: if ``inputs['hidden']`` is neither 2d nor 3d.
        """
        if 'encoder_output_state' in inputs:
            encoder_output_state = inputs['encoder_output_state']
        else:
            hidden = inputs['hidden']
            if len(hidden.shape) == 3:  # encoder_output is a sequence
                # reduce_sequence collapses the time axis to a [b, h]
                encoder_output_state = self.reduce_sequence(hidden)
            elif len(hidden.shape) == 2:
                # already a [b, h]
                encoder_output_state = hidden
            else:
                raise ValueError("Only works for 1d or 2d encoder_output")

        # Now reconcile the state's structure with the decoder cell:
        # an lstm needs a list [h, c], any other cell needs a plain tensor.
        is_list = isinstance(encoder_output_state, list)

        if self.cell_type == 'lstm':
            if not is_list:
                # duplicate the single tensor to serve as both h and c
                encoder_output_state = [
                    encoder_output_state, encoder_output_state
                ]
            elif len(encoder_output_state) == 2:
                # This may be a unidirectional lstm ([h, c]) or a
                # bidirectional gru / rnn ([fwd, bwd]); there is no way to
                # tell them apart here.  For a unidirectional lstm, passing
                # through is correct.  For a bidirectional gru / rnn, one
                # direction's output gets treated as the initial c of the
                # lstm, which is weird and may hurt performance.
                # todo try to find a way to distinguish these two cases
                pass
            elif len(encoder_output_state) == 4:
                # bidirectional lstm: average the two h and the two c
                # vectors ([h_fwd, c_fwd, h_bwd, c_bwd] -> [h, c])
                encoder_output_state = [
                    average(
                        [encoder_output_state[0], encoder_output_state[2]]
                    ),
                    average(
                        [encoder_output_state[1], encoder_output_state[3]]
                    )
                ]
            else:
                # Lists of length other than 2 or 4 should not occur; we
                # could raise a ValueError ("encoder_output_state has
                # length different than 2 or 4") but instead degrade
                # gracefully by averaging everything.
                average_state = average(encoder_output_state)
                encoder_output_state = [average_state, average_state]

        elif is_list:
            # non-lstm decoder cell but a list-shaped state: collapse it
            if len(encoder_output_state) == 2:
                # [h, c] (or [fwd, bwd]): use the first entry, ignore c
                encoder_output_state = encoder_output_state[0]
            elif len(encoder_output_state) == 4:
                # bidirectional lstm state: average the two hs, ignore cs.
                # BUGFIX: the original used '+' here instead of '=', which
                # computed the average and silently discarded it, leaving
                # the 4-element list in place.
                encoder_output_state = average(
                    [encoder_output_state[0], encoder_output_state[2]]
                )
            else:
                # unexpected length (see note in the lstm branch above):
                # degrade gracefully by averaging everything
                encoder_output_state = average(encoder_output_state)

        # else: non-lstm cell with a plain tensor — nothing to do.

        # At this point encoder_output_state is either a [b, h] tensor or
        # a list([b, h], [b, h]) if the decoder cell is an lstm — but h may
        # differ from the decoder state size, so project if needed.
        if isinstance(encoder_output_state, list):
            for i in range(len(encoder_output_state)):
                if (encoder_output_state[i].shape[1] !=
                        self.state_size):
                    encoder_output_state[i] = self.project(
                        encoder_output_state[i]
                    )
        else:
            if encoder_output_state.shape[1] != self.state_size:
                encoder_output_state = self.project(
                    encoder_output_state
                )

        return encoder_output_state