def predict()

in LearningMachine.py [0:0]


    def predict(self, predict_data_path, output_path, file_columns, predict_fields=None):
        """Run the model over a TSV file and append prediction columns to each row.

        Every data line of ``predict_data_path`` is copied to the file opened via
        ``open_and_move(output_path)`` with the requested ``predict_fields`` appended
        as extra tab-separated columns. When ``self.conf.file_with_col_header`` is
        set, the header line is copied through unchanged.

        Args:
            predict_data_path: path of the TSV file to predict on. If None,
                ``self.conf.predict_data_path`` is used.
            output_path: destination path, written through ``open_and_move``.
            file_columns: column description forwarded to ``self.problem.encode``.
            predict_fields: names of the fields appended to each row. Defaults to
                ``['prediction']``. For classification, ``'confidence'`` (score of
                the predicted label) and ``'confidence@<label>'`` (score of the
                given label) are also handled.

        Returns:
            None. Results are written to ``output_path``.
        """
        if predict_fields is None:
            predict_fields = ['prediction']
        if predict_data_path is None:
            predict_data_path = self.conf.predict_data_path

        predict_data, predict_length, _ = self.problem.encode(predict_data_path, file_columns, self.conf.input_types,
            self.conf.file_with_col_header, self.conf.object_inputs, None, min_sentence_len=self.conf.min_sentence_len,
            extra_feature=self.conf.extra_feature, max_lengths=self.conf.max_lengths, fixed_lengths=self.conf.fixed_lengths,
            file_format='tsv', show_progress=self.conf.mode == 'normal',
            cpu_num_workers=self.conf.cpu_num_workers, chunk_size=self.conf.chunk_size)

        logging.info("Starting predict ...")
        self.model.eval()
        # Unwrap DataParallel once so the output-layer inspection below is written only once.
        model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        problem_type = ProblemTypes[self.problem.problem_type]

        with torch.no_grad():
            data_batches, length_batches, _ = \
                get_batches(self.problem, predict_data, predict_length, None, self.conf.batch_size_total,
                    self.conf.input_types, None, permutate=False, transform_tensor=True)

            streaming_recoder = StreamingRecorder(predict_fields)

            # Both files are managed by `with` so the input is closed even if prediction fails.
            with open(predict_data_path, 'r', encoding='utf-8') as fin, open_and_move(output_path) as fout:
                if self.conf.file_with_col_header:
                    title_line = fin.readline()
                    fout.write(title_line)
                # Pick any non-target input to read per-sample lengths from.
                # (list.remove() returns None, so it cannot be used inline as before.)
                length_keys = [key for key in length_batches[0].keys() if key != 'target']
                key_random = random.choice(length_keys)
                if self.conf.mode == 'normal':
                    progress = tqdm(range(len(data_batches)))
                else:
                    # e.g. 'philly': plain range without a progress bar. Previously any other
                    # mode left `progress` undefined and crashed with a NameError.
                    progress = range(len(data_batches))
                for i in progress:
                    param_list, inputs_desc, length_desc = transform_params2tensors(data_batches[i], length_batches[i])
                    logits = self.model(inputs_desc, length_desc, *param_list)

                    # Linear output layers get a softmax unless their layer_conf sets
                    # last_hidden_softmax (presumably: the layer already emits probabilities).
                    logits_softmax = {}
                    for tmp_output_layer_id in model.output_layer_id:
                        output_layer = model.layers[tmp_output_layer_id]
                        if isinstance(output_layer, Linear) and not output_layer.layer_conf.last_hidden_softmax:
                            logits_softmax[tmp_output_layer_id] = nn.functional.softmax(
                                logits[tmp_output_layer_id], dim=-1)
                        else:
                            logits_softmax[tmp_output_layer_id] = logits[tmp_output_layer_id]

                    if problem_type == ProblemTypes.sequence_tagging:
                        logits = list(logits.values())[0]
                        # tmp_output_layer_id is the last id from the loop above; this assumes
                        # the sequence-tagging model has a single output layer.
                        if isinstance(get_layer_class(self.model, tmp_output_layer_id), CRF):
                            # The CRF layer returns a 6-tuple; only the decoded tag sequence is used.
                            forward_score, scores, masks, tag_seq, transitions, layer_conf = logits
                            prediction_indices = tag_seq.cpu().numpy()
                        else:
                            logits_softmax = list(logits_softmax.values())[0]
                            # Transform output shapes for metric evaluation
                            # for seq_tag_f1 metric
                            prediction_indices = logits_softmax.data.max(2)[1].cpu().numpy()  # [batch_size, seq_len]
                        prediction_batch = self.problem.decode(prediction_indices, length_batches[i][key_random].numpy())
                        for prediction_sample in prediction_batch:
                            streaming_recoder.record('prediction', " ".join(prediction_sample))
                        # len(logits) would be the CRF tuple arity (6), not the number of samples.
                        # Assumes prediction_indices is batch-first for CRF outputs too.
                        batch_size = len(prediction_indices)
                    elif problem_type == ProblemTypes.classification:
                        logits = list(logits.values())[0]
                        logits_softmax = list(logits_softmax.values())[0]
                        prediction_indices = logits_softmax.data.max(1)[1].cpu().numpy()

                        for field in predict_fields:
                            if field == 'prediction':
                                streaming_recoder.record(field,
                                    self.problem.decode(prediction_indices, length_batches[i][key_random].numpy()))
                            elif field == 'confidence':
                                prediction_scores = logits_softmax.cpu().data.numpy()
                                for prediction_score, prediction_idx in zip(prediction_scores, prediction_indices):
                                    streaming_recoder.record(field, prediction_score[prediction_idx])
                            elif field.startswith('confidence') and '@' in field:
                                # 'confidence@<label>': score of an explicitly named label.
                                label_specified = field.split('@')[1]
                                label_specified_idx = self.problem.output_dict.id(label_specified)
                                confidence_specified = torch.index_select(logits_softmax.cpu(), 1,
                                        torch.tensor([label_specified_idx], dtype=torch.long)).squeeze(1)
                                streaming_recoder.record(field, confidence_specified.data.numpy())
                        batch_size = len(logits)
                    elif problem_type == ProblemTypes.regression:
                        logits = list(logits.values())[0]
                        prediction_scores = logits.squeeze(1).detach().cpu().numpy()
                        streaming_recoder.record_one_row([prediction_scores])
                        batch_size = len(logits)
                    elif problem_type == ProblemTypes.mrc:
                        for key, value in logits.items():
                            logits[key] = value.squeeze()
                        for key, value in logits_softmax.items():
                            logits_softmax[key] = value.squeeze()
                        # Heuristic: the first input type whose name contains 'p' is the passage.
                        passage_identify = None
                        for type_key in data_batches[i].keys():
                            if 'p' in type_key.lower():
                                passage_identify = type_key
                                break
                        if not passage_identify:
                            raise Exception('MRC task need passage information.')
                        prediction = self.problem.decode(logits_softmax, lengths=length_batches[i][passage_identify],
                                                         batch_data=data_batches[i][passage_identify])
                        streaming_recoder.record_one_row([prediction])
                        batch_size = len(list(logits.values())[0])
                    else:
                        # Unhandled problem type: keep the original fallback.
                        batch_size = len(logits)

                    # Copy one input line per sample and append the recorded fields.
                    for sample_idx in range(batch_size):
                        # Strip only the line terminator: a bare rstrip() would also drop trailing
                        # empty TSV fields and shift the appended prediction columns.
                        sample = fin.readline().rstrip('\r\n')
                        fout.write("%s\t%s\n" % (sample,
                            "\t".join([str(streaming_recoder.get(field)[sample_idx]) for field in predict_fields])))
                    streaming_recoder.clear_records()

                    del logits, logits_softmax