def train_model()

in question_generation_model.py [0:0]


    def train_model(self, model, model_dir, epoch, batch_size, decoder_algorithm, beam_size,
                    test_epochs=0, test_images_per_epoch=2):
        """
        Training function.

        Each of the ``epoch`` iterations pulls ONE batch from the training
        generator, fits the model on it once (``epochs=1``) and saves the full
        model to ``<model_dir>/model_<n>.h5``.
        NOTE(review): an "epoch" here is a single generator batch, not a full
        pass over the training set -- ``steps`` is computed and logged but not
        used to drive training; confirm this is intended.

        :param model: model definition (anything exposing Keras-style ``fit`` and ``save``)
        :param model_dir: Directory where model is saved
        :param epoch: Number of epochs (training iterations)
        :param batch_size: Batch size, forwarded to ``generate_batch``
        :param decoder_algorithm: Decoder algorithm to be used with model: greedy, simple beam search or diverse beam search
        :param beam_size: Beam size to be used for decoding
        :param test_epochs: For the first ``test_epochs`` iterations, run inference on randomly
            sampled test images after saving. Defaults to 0 (disabled) to keep training fast.
        :param test_images_per_epoch: Number of test images sampled in each tested iteration
        :return: ``epoch`` (the number of iterations requested)
        """
        # NOTE: EarlyStopping / ModelCheckpoint callbacks, validation/test
        # generators and fit_generator-based training used to be set up here but
        # are disabled to increase training speed. The callbacks were never
        # passed to ``fit``, so they are no longer constructed (the
        # ``period`` argument of ModelCheckpoint is also deprecated in Keras).
        steps = int(len(self.datasets.train_image_id_questions_dict) // batch_size)
        val_steps = int(len(self.datasets.dev_image_id_questions_dict) // batch_size)
        self.logger.info('Train steps %s', steps)
        self.logger.info('Validation steps %s', val_steps)

        generator = self.generate_batch(batch_size,
                                        self.datasets.train_image_id_questions_dict,
                                        self.datasets.train_image_id_imagefeat_dict,
                                        self.datasets.train_image_id_keyword_dict)

        for no_epoch in range(epoch):
            self.logger.info('Epoch %s/%s', no_epoch + 1, epoch)
            x_train, y_train = next(generator)
            model.fit(x_train, y_train, epochs=1, verbose=2)

            model_file_name = os.path.join(model_dir, 'model_' + str(no_epoch) + '.h5')
            model.save(model_file_name)
            # Logged only after a successful save so the message is never misleading.
            self.logger.info('Model saved at %s', model_file_name)

            if no_epoch >= test_epochs:
                continue

            # Best-effort qualitative check on randomly sampled test images:
            # inference failures are logged with a traceback but never abort training.
            test_ids = list(self.datasets.test_image_id_url_dict)
            if not test_ids:
                self.logger.error('No test images available for sampling')
                continue
            for _ in range(test_images_per_epoch):
                image_id = sys_random.choice(test_ids)
                test_image_url = self.datasets.test_image_id_url_dict[image_id]
                self.logger.info('\n\n\n\n\nImage url: %s', test_image_url)
                try:
                    self.test_model(test_image_url, model, decoder_algorithm, beam_size)
                    gt_questions = self.datasets.test_image_id_questions_dict[image_id]
                    self.logger.info('GT  ---->%s', gt_questions)
                except Exception:
                    self.logger.exception('Error with inference code')

        return epoch