def train(self, optimizer, loss_fn)

in LearningMachine.py [0:0]


    def train(self, optimizer, loss_fn):
        self.model.train()
        logging.info("="*100 + '\n' + "*"*15 + 'Prepare data for training' + "*"*15)

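        # Encode the validation set once up front; the encoded features, lengths and targets are reused at every validation step below.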
        valid_data, valid_length, valid_target = self.problem.encode(self.conf.valid_data_path, self.conf.file_columns,
            self.conf.input_types, self.conf.file_with_col_header, self.conf.object_inputs, self.conf.answer_column_name, max_lengths=self.conf.max_lengths,
            min_sentence_len=self.conf.min_sentence_len, extra_feature=self.conf.extra_feature, fixed_lengths=self.conf.fixed_lengths, file_format='tsv',
            show_progress=(self.conf.mode == 'normal'), cpu_num_workers=self.conf.cpu_num_workers, chunk_size=self.conf.chunk_size)

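        # Optionally encode the test set; it is evaluated whenever the validation result improves (see below).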
        if self.conf.test_data_path is not None:
            test_data, test_length, test_target = self.problem.encode(self.conf.test_data_path, self.conf.file_columns,
                self.conf.input_types, self.conf.file_with_col_header, self.conf.object_inputs, self.conf.answer_column_name, max_lengths=self.conf.max_lengths,
                min_sentence_len=self.conf.min_sentence_len, extra_feature=self.conf.extra_feature, fixed_lengths=self.conf.fixed_lengths, file_format='tsv',
                show_progress=(self.conf.mode == 'normal'), cpu_num_workers=self.conf.cpu_num_workers, chunk_size=self.conf.chunk_size)

        stop_training = False
        epoch = 1
        best_result = None
        show_result_cnt = 0
        lr_scheduler = LRScheduler(optimizer, self.conf.lr_decay, self.conf.minimum_lr, self.conf.epoch_start_lr_decay)

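        # The fields recorded for metric computation depend on the problem type.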
        if ProblemTypes[self.problem.problem_type] == ProblemTypes.classification:
            streaming_recoder = StreamingRecorder(['prediction', 'pred_scores', 'pred_scores_all', 'target'])
        elif ProblemTypes[self.problem.problem_type] == ProblemTypes.sequence_tagging:
            streaming_recoder = StreamingRecorder(['prediction', 'pred_scores', 'target'])
        elif ProblemTypes[self.problem.problem_type] == ProblemTypes.regression:
            streaming_recoder = StreamingRecorder(['prediction', 'target'])
        elif ProblemTypes[self.problem.problem_type] == ProblemTypes.mrc:
            streaming_recoder = StreamingRecorder(['prediction', 'answer_text'])

        logging.info("=" * 100 + '\n' + "*" * 15 + 'Start training' + "*" * 15)
        while not stop_training and epoch <= self.conf.max_epoch:
            logging.info('Training: Epoch ' + str(epoch))
            train_data_generator = self._get_training_data_generator()
            part_index = 1
            for train_data, train_length, train_target in train_data_generator:
                logging.debug('Training: Epoch %s Part %s' % (epoch, part_index))
                part_index += 1
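                # Shuffle the current data part (permutate=True) and convert it into tensor batches.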
                data_batches, length_batches, target_batches = \
                    get_batches(self.problem, train_data, train_length, train_target, self.conf.batch_size_total,
                        self.conf.input_types, None, permutate=True, transform_tensor=True)

                whole_batch_num = len(target_batches)
                valid_batch_num = min(self.conf.steps_per_validation, whole_batch_num)
                small_batch_num = whole_batch_num
                valid_batch_num_show = valid_batch_num
                batch_num_to_show_results = self.conf.batch_num_to_show_results
                if torch.cuda.device_count() > 1:
                    batch_num_to_show_results *= torch.cuda.device_count()  # total batch count over all GPUs between result logs
                    small_batch_num *= torch.cuda.device_count()            # total batch count over all GPUs
                    valid_batch_num_show *= torch.cuda.device_count()       # total batch count over all GPUs between validations
                
                streaming_recoder.clear_records()
                all_costs = []

                logging.info('There are %d batches in the current period; validation is conducted every %d batches' % (small_batch_num, valid_batch_num_show))

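                # Show a tqdm progress bar only in interactive 'normal' mode; 'philly' (non-interactive cluster) mode iterates without one.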
                if self.conf.mode == 'normal':
                    progress = tqdm(range(len(target_batches)))
                elif self.conf.mode == 'philly':
                    progress = range(len(target_batches))
                for i in progress:
                    # the result shape: for classification: [batch_size, # of classes]; for sequence tagging: [batch_size, seq_len, # of tags]
                    param_list, inputs_desc, length_desc = transform_params2tensors(data_batches[i], length_batches[i])
                    logits = self.model(inputs_desc, length_desc, *param_list)

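                    # Apply softmax to Linear output layers that have not already applied one; CRF
                    # outputs are handled during decoding instead. Note that tmp_output_layer_id is
                    # reused after this loop, which implicitly assumes a single output layer.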
                    logits_softmax = {}
                    if isinstance(self.model, nn.DataParallel):
                        for tmp_output_layer_id in self.model.module.output_layer_id:
                            if isinstance(self.model.module.layers[tmp_output_layer_id], Linear) and \
                                    (not self.model.module.layers[tmp_output_layer_id].layer_conf.last_hidden_softmax):
                                logits_softmax[tmp_output_layer_id] = nn.functional.softmax(
                                    logits[tmp_output_layer_id], dim=-1)
                            elif isinstance(get_layer_class(self.model, tmp_output_layer_id), CRF):
                                pass
                            else:
                                logits_softmax[tmp_output_layer_id] = logits[tmp_output_layer_id]
                    else:
                        for tmp_output_layer_id in self.model.output_layer_id:
                            if isinstance(self.model.layers[tmp_output_layer_id], Linear) and \
                                    (not self.model.layers[tmp_output_layer_id].layer_conf.last_hidden_softmax):
                                logits_softmax[tmp_output_layer_id] = nn.functional.softmax(
                                    logits[tmp_output_layer_id], dim=-1)
                            elif isinstance(get_layer_class(self.model, tmp_output_layer_id), CRF):
                                pass
                            else:
                                logits_softmax[tmp_output_layer_id] = logits[tmp_output_layer_id]

                    # Check that the output shape matches what each problem type expects
                    if ProblemTypes[self.problem.problem_type] == ProblemTypes.classification:
                        logits = list(logits.values())[0]
                        logits_softmax = list(logits_softmax.values())[0]
                        assert len(logits_softmax.shape) == 2, 'The dimension of your output is %s, but we need [batch_size*GPUs, class num]' % (str(list(logits_softmax.shape)))
                        assert logits_softmax.shape[1] == self.problem.output_target_num(), 'The output dimension %d is inconsistent with the number of target classes %d!' % (logits_softmax.shape[1], self.problem.output_target_num())
                        # for auc metric
                        prediction_scores = logits_softmax[:, self.conf.pos_label].cpu().data.numpy()
                        if self.evaluator.has_auc_type_specific:
                            prediction_scores_all = logits_softmax.cpu().data.numpy()
                        else:
                            prediction_scores_all = None
                    elif ProblemTypes[self.problem.problem_type] == ProblemTypes.sequence_tagging:
                        logits = list(logits.values())[0]
                        if not isinstance(get_layer_class(self.model, tmp_output_layer_id), CRF):
                            logits_softmax = list(logits_softmax.values())[0]
                            assert len(logits_softmax.shape) == 3, 'The dimension of your output is %s, but we need [batch_size*GPUs, sequence length, representation dim]' % (str(list(logits_softmax.shape)), )
                        prediction_scores = None
                        prediction_scores_all = None
                    elif ProblemTypes[self.problem.problem_type] == ProblemTypes.regression:
                        logits = list(logits.values())[0]
                        logits_softmax = list(logits_softmax.values())[0]
                        assert len(logits_softmax.shape) == 2 and logits_softmax.shape[1] == 1, 'The dimension of your output is %s, but we need [batch_size*GPUs, 1]' % (str(list(logits_softmax.shape)))
                        prediction_scores = None
                        prediction_scores_all = None
                    elif ProblemTypes[self.problem.problem_type] == ProblemTypes.mrc:
                        for single_value in logits_softmax.values():
                            assert len(single_value.shape) == 3, 'The dimension of your output is %s, but we need [batch_size*GPUs, sequence_len, 1]' % (str(list(single_value.shape)))
                        prediction_scores = None
                        prediction_scores_all = None

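                    # Record predictions for metric computation and reshape outputs into logits_flat for the loss below.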
                    logits_flat = dict()
                    if ProblemTypes[self.problem.problem_type] == ProblemTypes.sequence_tagging:
                        # Transform output shapes for metric evaluation
                        # for seq_tag_f1 metric
                        if isinstance(get_layer_class(self.model, tmp_output_layer_id), CRF):
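                            # The CRF layer returns a tuple; tag_seq holds the decoded tag index sequence.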
                            forward_score, scores, masks, tag_seq, transitions, layer_conf = logits
                            prediction_indices = tag_seq.cpu().numpy()
                            streaming_recoder.record_one_row(
                                [self.problem.decode(prediction_indices, length_batches[i]['target'][self.conf.answer_column_name[0]].numpy()),
                                 prediction_scores,
                                 self.problem.decode(target_batches[i][self.conf.answer_column_name[0]],
                                                     length_batches[i]['target'][self.conf.answer_column_name[0]].numpy())],
                                keep_dim=False)

                        else:
                            prediction_indices = logits_softmax.data.max(2)[1].cpu().numpy()    # [batch_size, seq_len]
                            # pytorch's CrossEntropyLoss only support this
                            logits_flat[self.conf.output_layer_id[0]] = logits.view(-1, logits.size(2))  # [batch_size * seq_len, # of tags]
                            streaming_recoder.record_one_row(
                                [self.problem.decode(prediction_indices, length_batches[i]['target'][self.conf.answer_column_name[0]].numpy()),
                                 prediction_scores,
                                 self.problem.decode(target_batches[i][self.conf.answer_column_name[0]],
                                                     length_batches[i]['target'][self.conf.answer_column_name[0]].numpy())],
                                keep_dim=False)

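                            # Flatten targets to [batch_size * seq_len] to match the flattened logits above.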
                            target_batches[i][self.conf.answer_column_name[0]] = target_batches[i][
                                self.conf.answer_column_name[0]].reshape(-1)

                    elif ProblemTypes[self.problem.problem_type] == ProblemTypes.classification:
                        prediction_indices = logits_softmax.detach().max(1)[1].cpu().numpy()
                        # Do not decode here: classification predictions are label indices consumed directly by the evaluator.
                        streaming_recoder.record_one_row([prediction_indices, prediction_scores, prediction_scores_all, target_batches[i][self.conf.answer_column_name[0]].numpy()])
                        logits_flat[self.conf.output_layer_id[0]] = logits
                    elif ProblemTypes[self.problem.problem_type] == ProblemTypes.regression:
                        temp_logits_flat = logits.squeeze(1)
                        prediction_scores = temp_logits_flat.detach().cpu().numpy()
                        streaming_recoder.record_one_row([prediction_scores, target_batches[i][self.conf.answer_column_name[0]].numpy()])
                        logits_flat[self.conf.output_layer_id[0]] = temp_logits_flat
                    elif ProblemTypes[self.problem.problem_type] == ProblemTypes.mrc:
                        for key, value in logits.items():
                            logits[key] = value.squeeze()
                        for key, value in logits_softmax.items():
                            logits_softmax[key] = value.squeeze()
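                        # Heuristically identify the passage input by looking for 'p' in the input type name.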
                        passage_identify = None
                        for type_key in data_batches[i].keys():
                            if 'p' in type_key.lower():
                                passage_identify = type_key
                                break
                        if not passage_identify:
                            raise Exception('MRC task needs passage information.')
                        prediction = self.problem.decode(logits_softmax, lengths=length_batches[i][passage_identify],
                                                        batch_data=data_batches[i][passage_identify])
                        logits_flat = logits
                        mrc_answer_target = None
                        for single_target in target_batches[i]:
                            if isinstance(target_batches[i][single_target][0], str):
                                mrc_answer_target = target_batches[i][single_target]
                        streaming_recoder.record_one_row([prediction, mrc_answer_target])

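                    # Move target tensors to the GPU before computing the loss.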
                    if self.use_gpu:
                        for single_target in self.conf.answer_column_name:
                            if isinstance(target_batches[i][single_target], torch.Tensor):
                                target_batches[i][single_target] = transfer_to_gpu(target_batches[i][single_target])
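                    # CRFLoss consumes the raw CRF outputs unpacked above; all other losses consume logits_flat.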
                    if isinstance(loss_fn.loss_fn[0], CRFLoss):
                        loss = loss_fn.loss_fn[0](forward_score, scores, masks, list(target_batches[i].values())[0], transitions, layer_conf)
                    else:
                        loss = loss_fn(logits_flat, target_batches[i])

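                    # Record the loss, then zero gradients, backpropagate, optionally clip gradient norms, and step the optimizer.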
                    all_costs.append(loss.item())
                    optimizer.zero_grad()
                    loss.backward()
                    if self.conf.clip_grad_norm_max_norm != -1:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.conf.clip_grad_norm_max_norm)
                        if isinstance(self.model, nn.DataParallel):
                            torch.nn.utils.clip_grad_norm_(self.model.module.layers['embedding'].get_parameters(), self.conf.clip_grad_norm_max_norm)
                        else:
                            torch.nn.utils.clip_grad_norm_(self.model.layers['embedding'].get_parameters(), self.conf.clip_grad_norm_max_norm)
                    optimizer.step()

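                    # Release per-batch tensors promptly to reduce GPU memory pressure.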
                    del loss, logits, logits_softmax, logits_flat
                    del prediction_scores
                    if ProblemTypes[self.problem.problem_type] == ProblemTypes.sequence_tagging \
                            or ProblemTypes[self.problem.problem_type] == ProblemTypes.classification:
                        del prediction_indices

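                    # Every batch_num_to_show_results batches, evaluate the accumulated predictions and log the running loss.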
                    if show_result_cnt == batch_num_to_show_results:
                        if ProblemTypes[self.problem.problem_type] == ProblemTypes.classification:
                            result = self.evaluator.evaluate(streaming_recoder.get('target'),
                                streaming_recoder.get('prediction'), y_pred_pos_score=streaming_recoder.get('pred_scores'),
                                y_pred_scores_all=streaming_recoder.get('pred_scores_all'), formatting=True)
                        elif ProblemTypes[self.problem.problem_type] == ProblemTypes.sequence_tagging:
                            result = self.evaluator.evaluate(streaming_recoder.get('target'),
                                streaming_recoder.get('prediction'), y_pred_pos_score=streaming_recoder.get('pred_scores'),
                                formatting=True)
                        elif ProblemTypes[self.problem.problem_type] == ProblemTypes.regression:
                            result = self.evaluator.evaluate(streaming_recoder.get('target'),
                                streaming_recoder.get('prediction'), y_pred_pos_score=None, y_pred_scores_all=None, formatting=True)
                        elif ProblemTypes[self.problem.problem_type] == ProblemTypes.mrc:
                            result = self.evaluator.evaluate(streaming_recoder.get('answer_text'), streaming_recoder.get('prediction'),
                                                                y_pred_pos_score=None, y_pred_scores_all=None, formatting=True)

                        shown_batch_idx = i * torch.cuda.device_count() if torch.cuda.device_count() > 1 else i
                        logging.info("Epoch %d batch idx: %d; lr: %f; since last log, loss=%f; %s" % \
                            (epoch, shown_batch_idx, lr_scheduler.get_lr(), np.mean(all_costs), result))
                        show_result_cnt = 0
                        # The loss and metrics logged here reflect only the batches since the last log, not the full training set.
                        all_costs = []
                        streaming_recoder.clear_records()

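                    # Validate every valid_batch_num batches and at the end of each part; evaluate() saves the model whenever the validation result improves.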
                    if (i != 0 and i % valid_batch_num == 0) or i == len(target_batches) - 1:
                        torch.cuda.empty_cache()    # releases cached GPU memory; has little practical effect here
                        logging.info('Valid & Test : Epoch ' + str(epoch))
                        new_result = self.evaluate(valid_data, valid_length, valid_target,
                            self.conf.input_types, self.evaluator, loss_fn, pad_ids=None, cur_best_result=best_result,
                            model_save_path=self.conf.model_save_path, phase="valid", epoch=epoch)
                        renew_flag = best_result != new_result
                        best_result = new_result

                        if renew_flag and self.conf.test_data_path is not None:
                            self.evaluate(test_data, test_length, test_target,
                                self.conf.input_types, self.evaluator, loss_fn, pad_ids=None, phase="test", epoch=epoch)
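                        # evaluate() leaves the model in eval mode; switch back to training mode before continuing.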
                        self.model.train()
                    show_result_cnt += 1

                del data_batches, length_batches, target_batches
            lr_scheduler.step()
            epoch += 1
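
For context, here is a minimal sketch of how this method might be driven. Apart from the train(optimizer, loss_fn) call itself, everything below is an assumption for illustration: this excerpt does not show how LearningMachine, its configuration, or the loss wrapper are constructed.

    # Hypothetical driver sketch. The LearningMachine and Loss constructors below
    # are assumptions; only train(optimizer, loss_fn) is confirmed by the excerpt.
    import torch.optim as optim

    lm = LearningMachine(conf, problem)                       # hypothetical constructor signature
    optimizer = optim.Adam(lm.model.parameters(), lr=0.001)   # standard PyTorch optimizer
    loss_fn = Loss(conf)                                      # hypothetical wrapper; per the code above it must expose
                                                              # a .loss_fn list and be callable as loss_fn(logits_flat, targets)
    lm.train(optimizer, loss_fn)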