def greedy_search()

in question_generation_model.py


    def greedy_search(self, image_input, model, keyword=None):
        """
        Decoding strategy of choosing highest scoring token at each step
        :param image_input: image input features
        :param model: model definition file
        :param keyword: Keyword input features
        :return:
        """

        # GloVe models take fixed-length padded index sequences, so read the
        # maximum sequence length from the model's second input tensor.
        if 'glove' in self.datasets.embedding_file:
            max_seq_len = model.inputs[1].shape[1].value
        # ELMo and BERT models are fed the partial question as a batch of two
        # identical strings (see below), so the image features are duplicated
        # to match that batch size.
        elif 'elmo' in self.datasets.embedding_file:
            max_seq_len = self.datasets.max_question_len
            image_input = np.repeat(image_input, axis=0, repeats=2)
        elif 'bert' in self.datasets.embedding_file:
            max_seq_len = self.datasets.max_question_len
            image_input = np.repeat(image_input, axis=0, repeats=2)
        else:
            self.logger.error('Embedding strategy not supported')
            exit(-1)
        self.logger.info('Max len %s' % max_seq_len)
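        # Seed the generated sequence with the start token and accumulate token probabilities.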
        in_text = '<START>'
        prob = 0
        for i in range(max_seq_len):
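            # Map the words generated so far to vocabulary indices.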
            sequence = [self.word_to_idx[w] for w in in_text.split(' ') if w in self.word_to_idx]
            if self.datasets.use_keyword:
                # Keyword-conditioned model: pad the index sequence and predict with the keyword features.
                sequence = pad_sequences([sequence], maxlen=max_seq_len, padding='post')
                yhat = model.predict([image_input, sequence, keyword])[0]
            elif 'glove' in self.datasets.embedding_file:
                sequence = pad_sequences([sequence], maxlen=max_seq_len, padding='post')
                yhat = model.predict([image_input, sequence])[0]
            elif 'elmo' in self.datasets.embedding_file:
                # ELMo consumes raw text; feed the partial question as a batch of two
                # identical strings to match the duplicated image features.
                sequence = self.cleanText(in_text)
                sequence = np.array([[sequence], [sequence]])
                self.logger.debug('Sequence %s shape %s image %s' % (sequence, sequence.shape, image_input.shape))
                yhat = model.predict([image_input, sequence])[0]
            elif 'bert' in self.datasets.embedding_file:
                # BERT consumes raw text without the <START> token; build input ids,
                # attention masks and segment ids before prediction.
                sequence = ' '.join(in_text.split(' ')[1:])
                sequence = self.cleanText(sequence)
                sequence = [[sequence], [sequence]]
                input_ids, input_masks, segment_ids, _ = preprocess_bert_input(
                    sequence, [None] * len(sequence),
                    self.datasets.max_question_len, self.tokenizer, self.vocab_size)
                yhat = model.predict([image_input, input_ids, input_masks, segment_ids])[0]
            else:
                self.logger.error('Embedding strategy not supported')
                exit(-1)

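            # Greedy step: take the highest-scoring token and append it to the sequence.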
            yhat_max = np.argmax(yhat)
            prob += yhat[yhat_max]
            word = self.datasets.idx_to_word[yhat_max]
            in_text += ' ' + word
            if word == '<END>':
                break
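        # Strip the leading <START> and trailing <END> tokens from the generated sequence.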
        final = ' '.join(in_text.split()[1:-1])

        if final not in self.datasets.unique_train_questions:
            self.logger.info('Generated question not seen in training data: %s' % final)
        self.datasets.unique_generated_questions.add(final)
        self.datasets.generated_questions.append([final])

        self.logger.info('Final greedy candidate: %s' % final)
        return {final: prob}
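
For orientation, a minimal, self-contained sketch of the same greedy-decoding loop is shown below. The `predict_next` callable, toy vocabulary, and fake predictor are illustrative stand-ins (not part of this project) for the embedding-specific `model.predict` branches above.

    import numpy as np

    def greedy_decode(predict_next, word_to_idx, idx_to_word, max_len):
        """Pick the argmax token at every step until <END> or max_len is reached."""
        in_text = '<START>'
        prob = 0.0
        for _ in range(max_len):
            indices = [word_to_idx[w] for w in in_text.split(' ') if w in word_to_idx]
            yhat = predict_next(indices)   # next-token probability distribution
            best = int(np.argmax(yhat))    # greedy choice: highest-scoring token
            prob += float(yhat[best])
            word = idx_to_word[best]
            in_text += ' ' + word
            if word == '<END>':
                break
        question = ' '.join(in_text.split()[1:-1])
        return {question: prob}

    # Toy usage: a fake predictor that walks through the vocabulary and ends the question.
    vocab = ['<START>', 'what', 'is', 'this', '<END>']
    word_to_idx = {w: i for i, w in enumerate(vocab)}
    idx_to_word = {i: w for w, i in word_to_idx.items()}
    fake = lambda seq: np.eye(len(vocab))[min(len(seq), len(vocab) - 1)]
    print(greedy_decode(fake, word_to_idx, idx_to_word, max_len=10))
    # -> {'what is this': 4.0}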