def interactive()

in LearningMachine.py


    def interactive(self, sample, file_columns, predict_fields=['prediction'], predict_mode='batch'):
        """ interactive prediction

         Args:
            file_columns: representation the columns of sample
            predict_mode: interactive|batch(need a predict file)
        """
        predict_data, predict_length, _, _, _ = \
            self.problem.encode_data_list(sample, file_columns, self.conf.input_types, self.conf.object_inputs, None,
                                          self.conf.min_sentence_len, self.conf.extra_feature, self.conf.max_lengths,
                                          self.conf.fixed_lengths, predict_mode=predict_mode)
        if predict_data is None:
            return 'Wrong Case!'
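        # Inference only: switch to eval mode and disable gradient tracking;
        # the encoded sample is packed into a single batch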
        self.model.eval()
        with torch.no_grad():
            data_batches, length_batches, _ = \
                get_batches(self.problem, predict_data, predict_length, None, 1,
                            self.conf.input_types, None, permutate=False, transform_tensor=True, predict_mode=predict_mode)
            streaming_recoder = StreamingRecorder(predict_fields)

            # Pick any non-target input key; its length tensor is used below
            # when decoding predictions
            candidate_keys = list(length_batches[0].keys())
            if 'target' in candidate_keys:
                candidate_keys.remove('target')
            key_random = random.choice(candidate_keys)
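            # Assemble tensors in the order the model expects and run one forward pass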
            param_list, inputs_desc, length_desc = transform_params2tensors(data_batches[0], length_batches[0])
            logits = self.model(inputs_desc, length_desc, *param_list)

            # Apply softmax to any Linear output layer that does not already
            # end with one, so decoding always sees normalized probabilities
            model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
            logits_softmax = {}
            for tmp_output_layer_id in model.output_layer_id:
                if isinstance(model.layers[tmp_output_layer_id], Linear) and \
                        (not model.layers[tmp_output_layer_id].layer_conf.last_hidden_softmax):
                    logits_softmax[tmp_output_layer_id] = nn.functional.softmax(
                        logits[tmp_output_layer_id], dim=-1)
                else:
                    logits_softmax[tmp_output_layer_id] = logits[tmp_output_layer_id]

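            # Decode the normalized outputs into readable predictions, per problem type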
            if ProblemTypes[self.problem.problem_type] == ProblemTypes.sequence_tagging:
                logits = list(logits.values())[0]
                if isinstance(get_layer_class(self.model, tmp_output_layer_id), CRF):
                    forward_score, scores, masks, tag_seq, transitions, layer_conf = logits
                    prediction_indices = tag_seq.cpu().numpy()
                else:
                    logits_softmax = list(logits_softmax.values())[0]
                    # take the argmax over the tag dimension
                    prediction_indices = logits_softmax.data.max(2)[1].cpu().numpy()  # [batch_size, seq_len]
                prediction_batch = self.problem.decode(prediction_indices, length_batches[0][key_random].numpy())
                for prediction_sample in prediction_batch:
                    streaming_recoder.record('prediction', " ".join(prediction_sample))
            elif ProblemTypes[self.problem.problem_type] == ProblemTypes.classification:
                logits = list(logits.values())[0]
                logits_softmax = list(logits_softmax.values())[0]
                prediction_indices = logits_softmax.data.max(1)[1].cpu().numpy()

                for field in predict_fields:
                    if field == 'prediction':
                        streaming_recoder.record(field,
                                                 self.problem.decode(prediction_indices,
                                                                     length_batches[0][key_random].numpy()))
                    elif field == 'confidence':
                        prediction_scores = logits_softmax.cpu().data.numpy()
                        for prediction_score, prediction_idx in zip(prediction_scores, prediction_indices):
                            streaming_recoder.record(field, prediction_score[prediction_idx])
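                    # 'confidence@<label>' reports the probability of one specific label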
                    elif field.startswith('confidence') and field.find('@') != -1:
                        label_specified = field.split('@')[1]
                        label_specified_idx = self.problem.output_dict.id(label_specified)
                        confidence_specified = torch.index_select(
                            logits_softmax.cpu(), 1,
                            torch.tensor([label_specified_idx], dtype=torch.long)).squeeze(1)
                        streaming_recoder.record(field, confidence_specified.data.numpy())
            elif ProblemTypes[self.problem.problem_type] == ProblemTypes.regression:
                logits = list(logits.values())[0]
                # logits_softmax is not used for the regression task
                logits_flat = logits.squeeze(1)
                prediction_scores = logits_flat.detach().cpu().numpy()
                streaming_recoder.record_one_row([prediction_scores])
            elif ProblemTypes[self.problem.problem_type] == ProblemTypes.mrc:
                for key, value in logits.items():
                    logits[key] = value.squeeze()
                for key, value in logits_softmax.items():
                    logits_softmax[key] = value.squeeze()
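                # Heuristically locate the passage input: the first column whose
                # name contains 'p' (e.g. 'passage') is treated as the passage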
                passage_identify = None
                for type_key in data_batches[0].keys():
                    if 'p' in type_key.lower():
                        passage_identify = type_key
                        break
                if not passage_identify:
                    raise Exception('MRC task needs passage information.')
                prediction = self.problem.decode(logits_softmax, lengths=length_batches[0][passage_identify],
                                                 batch_data=data_batches[0][passage_identify])
                streaming_recoder.record_one_row([prediction])

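            # Emit one tab-separated value per requested predict_field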
            return "\t".join([str(streaming_recoder.get(field)[0]) for field in predict_fields])