in question_generation_model.py
def generate_batch(self, batch_size, id_question_dict, id_imagefeat_dict, id_keyword_dict=None, shuffle=True):
    """
    Generator function responsible for generating batches during training
    :param batch_size: Batch size to be used
    :param id_question_dict: Dict with image id as key and question list as value
    :param id_imagefeat_dict: Dict with image id as key and image features as value
    :param id_keyword_dict: Dict with image id as key and keyword as value
    :param shuffle: If shuffle is true, randomly shuffle the image ids
    :return: Yields one batch per step as [model inputs, targets]
    """
    while True:
        image_ids = list(id_question_dict.keys())
        image_ids = [image_id for image_id in image_ids if image_id in id_imagefeat_dict]
        num_samples = len(image_ids)
        if shuffle:
            random.shuffle(image_ids)
        # Get index to start each batch:
        # [0, batch_size, 2*batch_size, ..., max multiple of batch_size <= num_samples]
        for offset in range(0, num_samples, batch_size):
            X1 = list()
            X2 = list()
            X3 = list()
            Y = list()
            bert_label = list()
            # Get the samples you'll use in this batch
            batch_samples = image_ids[offset:offset + batch_size]
            for image_id in batch_samples:
                try:
                    image_feature = id_imagefeat_dict[image_id]
                    if image_feature is None:
                        self.logger.debug('Image %s has no feature' % image_id)
                        continue
                except KeyError:
                    self.logger.error('Image %s not found' % image_id)
                    continue
                try:
                    keyword = id_keyword_dict[image_id]
                except (KeyError, TypeError):
                    # No keyword dict supplied, or no keyword for this image
                    keyword = " "
                x1 = image_feature
                image_questions = id_question_dict[image_id]
                for image_question in image_questions:
                    token_seq = [self.datasets.word_to_idx[word] for word in image_question.split(' ')
                                 if word in self.datasets.word_to_idx]
                    for i in range(1, len(token_seq)):
                        in_seq, out_seq = token_seq[:i], token_seq[i]
                        bert_label.append(out_seq)
                        y = to_categorical([out_seq], num_classes=self.vocab_size)[0]
                        Y.append(y)
                        X1.append(x1)
                        if self.datasets.use_keyword:
                            x2_glove = pad_sequences([in_seq], maxlen=self.datasets.max_question_len, padding='post')[0]
                            X2.append(x2_glove)
                            keyword_token_seq = [self.datasets.word_to_idx[word] for word in keyword.split(' ')
                                                 if word in self.datasets.word_to_idx]
                            keyword_tokens = pad_sequences([keyword_token_seq], maxlen=self.datasets.max_keyword_len, padding='post')[0]
                            X3.append(keyword_tokens)
                        elif 'glove' in self.datasets.embedding_file:
                            x2_glove = pad_sequences([in_seq], maxlen=self.datasets.max_question_len, padding='post')[0]
                            X2.append(x2_glove)
                        # ELMo takes the partial question as raw text instead of padded token ids
                        elif 'elmo' in self.datasets.embedding_file:
                            x2_elmo = ' '.join([self.datasets.idx_to_word[idx] for idx in in_seq[1:]])
                            x2_elmo = self.cleanText(x2_elmo)
                            X2.append([x2_elmo])
                        # BERT also takes raw text; it is re-tokenized in preprocess_bert_input below
                        elif 'bert' in self.datasets.embedding_file:
                            x2_bert = ' '.join([self.datasets.idx_to_word[idx] for idx in in_seq[1:]])
                            x2_bert = self.cleanText(x2_bert)
                            X2.append([x2_bert])
            if self.datasets.use_keyword:
                self.logger.debug('Keyword batch shapes: %s %s %s %s' %
                                  (array(X1).shape, array(X2).shape, array(X3).shape, array(Y).shape))
                yield [[array(X1), array(X2), array(X3)], array(Y)]
            # BERT input is slightly different from the rest
            elif 'bert' in self.datasets.embedding_file:
                input_ids, input_masks, segment_ids, labels = preprocess_bert_input(X2,
                                                                                    bert_label,
                                                                                    self.datasets.max_question_len,
                                                                                    self.tokenizer,
                                                                                    self.vocab_size)
                yield [[array(X1), array(input_ids), array(input_masks), array(segment_ids)], array(labels)]
            else:
                self.logger.debug('Batch shapes: %s %s %s' %
                                  (array(X1).shape, array(X2).shape, array(Y).shape))
                yield [[array(X1), array(X2)], array(Y)]
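    # Usage sketch (illustrative only, not code from this repository): the generator is meant
    # to be consumed by Keras-style training, roughly as below. `qgen_model`, `keras_model`,
    # `train_questions`, `train_feats`, and `train_keywords` are assumed names for the model
    # wrapper, the compiled Keras model, and the three id-keyed dicts.
    #
    #   batch_size = 32
    #   steps = max(1, len(train_questions) // batch_size)
    #   keras_model.fit_generator(
    #       qgen_model.generate_batch(batch_size, train_questions, train_feats, train_keywords),
    #       steps_per_epoch=steps,
    #       epochs=10)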