# question_generation_model.py
def greedy_search(self, image_input, model, keyword=None):
    """
    Decoding strategy of choosing the highest scoring token at each step.

    :param image_input: image input features (numpy array)
    :param model: trained model used for next-token prediction
    :param keyword: keyword input features; used only when
        ``self.datasets.use_keyword`` is set
    :return: dict mapping the generated question string to its
        accumulated (summed) per-step token probability
    """
    # Validate the embedding strategy up front. Previously an unsupported
    # embedding left ``max_seq_len`` unbound and crashed with a NameError
    # before the intended "not supported" branch (inside the loop) could run.
    embedding_file = self.datasets.embedding_file
    if 'glove' in embedding_file:
        max_seq_len = model.inputs[1].shape[1].value
    elif 'elmo' in embedding_file or 'bert' in embedding_file:
        max_seq_len = self.datasets.max_question_len
        # ELMo/BERT branches feed the model a duplicated batch of two,
        # so the image features are duplicated to match.
        image_input = np.repeat(image_input, axis=0, repeats=2)
    else:
        self.logger.error('Embedding strategy not supported')
        exit(-1)
    self.logger.info('Max len %s' % max_seq_len)

    in_text = '<START>'
    prob = 0
    for _ in range(max_seq_len):
        # Map the tokens generated so far to vocabulary indices,
        # silently skipping out-of-vocabulary tokens.
        sequence = [self.word_to_idx[w] for w in in_text.split(' ') if w in self.word_to_idx]
        if self.datasets.use_keyword:
            sequence = pad_sequences([sequence], maxlen=max_seq_len, padding='post')
            yhat = model.predict([image_input, sequence, keyword])[0]
        elif 'glove' in embedding_file:
            sequence = pad_sequences([sequence], maxlen=max_seq_len, padding='post')
            yhat = model.predict([image_input, sequence])[0]
        elif 'elmo' in embedding_file:
            # ELMo consumes raw text; duplicate to match the batch of two.
            sequence = self.cleanText(in_text)
            sequence = np.array([[sequence], [sequence]])
            self.logger.debug('Sequence %s shape %s image %s' % (sequence, sequence.shape, image_input.shape))
            yhat = model.predict([image_input, sequence])[0]
        else:  # 'bert' — the only remaining possibility after up-front validation
            # Drop our '<START>' sentinel; BERT preprocessing adds its own markers.
            sequence = self.cleanText(' '.join(in_text.split(' ')[1:]))
            sequence = [[sequence], [sequence]]
            input_ids, input_masks, segment_ids, _ = preprocess_bert_input(
                sequence, [None] * len(sequence),
                self.datasets.max_question_len, self.tokenizer, self.vocab_size)
            yhat = model.predict([image_input, input_ids, input_masks, segment_ids])[0]
        # Greedy step: take the single most probable next token.
        yhat_max = np.argmax(yhat)
        prob += yhat[yhat_max]
        word = self.datasets.idx_to_word[yhat_max]
        in_text += ' ' + word
        if word == '<END>':
            break

    # Strip the sentinel tokens. Only drop the trailing token when it really
    # is '<END>'; the old unconditional ``[1:-1]`` slice discarded a real word
    # whenever decoding stopped at max_seq_len without emitting '<END>'.
    tokens = in_text.split()[1:]
    if tokens and tokens[-1] == '<END>':
        tokens = tokens[:-1]
    final = ' '.join(tokens)

    if final not in self.datasets.unique_train_questions:
        self.logger.info(
            'Unique generated questions not seen in training data: %s' % final)
    self.datasets.unique_generated_questions.add(final)
    self.datasets.generated_questions.append([final])
    self.logger.info('Final greedy candidate: %s' % final)
    return {final: prob}