in question_generation_model.py [0:0]
def train_model(self, model, model_dir, epoch, batch_size, decoder_algorithm, beam_size,
                num_test_epochs=0, num_test_images=2):
    """
    Training function: train for ``epoch`` iterations and save a model snapshot after each.

    Each iteration draws ONE batch from ``self.generate_batch`` and calls
    ``model.fit`` on it with ``epochs=1``. An "epoch" here is therefore a
    single batch, not a full pass over the training set. The computed
    ``steps`` / ``val_steps`` values are only logged.

    :param model: model definition (must provide ``fit`` and ``save``)
    :param model_dir: Directory where model snapshots are saved
    :param epoch: Number of training iterations (one batch each)
    :param batch_size: Batch size
    :param decoder_algorithm: Decoder algorithm passed to ``self.test_model``:
        greedy, simple beam search or diverse beam search
    :param beam_size: Beam size passed to ``self.test_model`` for decoding
    :param num_test_epochs: Run a sampled-inference sanity check after each of
        the first ``num_test_epochs`` iterations. The default 0 disables it,
        matching the previous behavior (the check used to be unreachable).
    :param num_test_images: Number of random test images decoded per sanity check
    :return: ``epoch``, unchanged
    """
    # NOTE: validation generator, ModelCheckpoint/EarlyStopping callbacks and
    # fit_generator were disabled to increase training speed; only a plain
    # per-iteration snapshot is written below.
    steps = int(len(self.datasets.train_image_id_questions_dict) // batch_size)
    val_steps = int(len(self.datasets.dev_image_id_questions_dict) // batch_size)
    self.logger.info('Train steps %s', steps)
    self.logger.info('Validation steps %s', val_steps)

    generator = self.generate_batch(batch_size,
                                    self.datasets.train_image_id_questions_dict,
                                    self.datasets.train_image_id_imagefeat_dict,
                                    self.datasets.train_image_id_keyword_dict)

    # Build the candidate test-id list once (not on every draw), and only touch
    # the test split when the sanity check is actually enabled.
    test_ids = list(self.datasets.test_image_id_url_dict) if num_test_epochs > 0 else []
    if num_test_epochs > 0 and not test_ids:
        # sys_random.choice([]) would raise IndexError and abort training.
        self.logger.warning('No test images available; skipping sampled inference')

    for no_epoch in range(epoch):
        self.logger.info('\n' * 5)
        self.logger.info(no_epoch)
        # generate_batch yields (inputs, targets) pairs; one pair per iteration.
        x_train, y_train = next(generator)
        model.fit(x_train, y_train, epochs=1, verbose=2)

        model_file_name = os.path.join(model_dir, 'model_' + str(no_epoch) + '.h5')
        self.logger.info('Model saved at %s', model_file_name)
        model.save(model_file_name)

        # Test model on ids sampled from the test set.
        if no_epoch < num_test_epochs and test_ids:
            for _ in range(num_test_images):
                image_id = sys_random.choice(test_ids)
                test_image_url = self.datasets.test_image_id_url_dict[image_id]
                self.logger.info('\n\n\n\n\nImage url: %s', test_image_url)
                try:
                    self.test_model(test_image_url, model, decoder_algorithm, beam_size)
                    gt_questions = self.datasets.test_image_id_questions_dict[image_id]
                    self.logger.info('GT ---->%s', gt_questions)
                except Exception:
                    # Best-effort sanity check: record the traceback but keep training.
                    self.logger.exception('Error with inference code')
    return epoch