in LearningMachine.py [0:0]
def predict(self, predict_data_path, output_path, file_columns, predict_fields=['prediction']):
""" prediction
Args:
predict_data_path:
predict_fields: default: only prediction. For classification and regression tasks, prediction_confidence is also supported.
Returns:
"""
if predict_data_path is None:
predict_data_path = self.conf.predict_data_path
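    # Encode the raw TSV into index tensors; the per-sample lengths are kept so
    # that padding can be ignored when the predictions are decoded later.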
    predict_data, predict_length, _ = self.problem.encode(
        predict_data_path, file_columns, self.conf.input_types,
        self.conf.file_with_col_header, self.conf.object_inputs, None,
        min_sentence_len=self.conf.min_sentence_len,
        extra_feature=self.conf.extra_feature,
        max_lengths=self.conf.max_lengths, fixed_lengths=self.conf.fixed_lengths,
        file_format='tsv', show_progress=(self.conf.mode == 'normal'),
        cpu_num_workers=self.conf.cpu_num_workers, chunk_size=self.conf.chunk_size)
logging.info("Starting predict ...")
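    # Inference mode: eval() freezes dropout/batch-norm behavior, and
    # torch.no_grad() disables gradient tracking to save memory.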
self.model.eval()
with torch.no_grad():
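        # Batch the encoded data; permutate=False keeps the original row order so
        # predictions can be aligned with the input file line by line.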
data_batches, length_batches, _ = \
get_batches(self.problem, predict_data, predict_length, None, self.conf.batch_size_total,
self.conf.input_types, None, permutate=False, transform_tensor=True)
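        # Buffers one batch of values per requested output field; it is flushed
        # to the output file at the end of every batch.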
        streaming_recorder = StreamingRecorder(predict_fields)
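        # Re-open the raw input file so each original line can be echoed into
        # the output next to its predicted fields.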
fin = open(predict_data_path, 'r', encoding='utf-8')
with open_and_move(output_path) as fout:
if self.conf.file_with_col_header:
title_line = fin.readline()
fout.write(title_line)
            # Any non-target key works here: its length tensor records the true
            # (un-padded) length of every sample, which decode() needs.
            # Note: list.remove() returns None, so it cannot be chained into
            # random.choice().
            candidate_keys = list(length_batches[0].keys())
            if 'target' in candidate_keys:
                candidate_keys.remove('target')
            key_random = random.choice(candidate_keys)
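            # Show a tqdm progress bar only in interactive ('normal') mode.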
if self.conf.mode == 'normal':
progress = tqdm(range(len(data_batches)))
elif self.conf.mode == 'philly':
progress = range(len(data_batches))
for i in progress:
# batch_size_actual = target_batches[i].size(0)
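                # Flatten the dict-shaped batch into a positional tensor list plus
                # descriptors; the model forward uses the descriptors to rebuild
                # its named inputs.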
param_list, inputs_desc, length_desc = transform_params2tensors(data_batches[i], length_batches[i])
logits = self.model(inputs_desc, length_desc, *param_list)
                logits_softmax = {}
                # Unwrap nn.DataParallel (if used) so the output layers can be
                # inspected uniformly.
                model_ref = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
                for tmp_output_layer_id in model_ref.output_layer_id:
                    if isinstance(model_ref.layers[tmp_output_layer_id], Linear) and \
                            (not model_ref.layers[tmp_output_layer_id].layer_conf.last_hidden_softmax):
                        # A Linear output layer without built-in softmax: normalize
                        # here so the reported confidences are probabilities.
                        logits_softmax[tmp_output_layer_id] = nn.functional.softmax(
                            logits[tmp_output_layer_id], dim=-1)
                    else:
                        logits_softmax[tmp_output_layer_id] = logits[tmp_output_layer_id]
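                # Decode the model output according to the problem type.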
if ProblemTypes[self.problem.problem_type] == ProblemTypes.sequence_tagging:
logits = list(logits.values())[0]
if isinstance(get_layer_class(self.model, tmp_output_layer_id), CRF):
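                        # A CRF output layer returns a tuple; tag_seq holds the
                        # decoded tag sequence for each sample.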
forward_score, scores, masks, tag_seq, transitions, layer_conf = logits
prediction_indices = tag_seq.cpu().numpy()
else:
logits_softmax = list(logits_softmax.values())[0]
# Transform output shapes for metric evaluation
# for seq_tag_f1 metric
prediction_indices = logits_softmax.data.max(2)[1].cpu().numpy() # [batch_size, seq_len]
prediction_batch = self.problem.decode(prediction_indices, length_batches[i][key_random].numpy())
for prediction_sample in prediction_batch:
                        streaming_recorder.record('prediction', " ".join(prediction_sample))
elif ProblemTypes[self.problem.problem_type] == ProblemTypes.classification:
logits = list(logits.values())[0]
logits_softmax = list(logits_softmax.values())[0]
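                    # max(1)[1] is the argmax over the class dimension, i.e. the
                    # predicted label index for every sample in the batch.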
prediction_indices = logits_softmax.data.max(1)[1].cpu().numpy()
                    for field in predict_fields:
                        if field == 'prediction':
                            streaming_recorder.record(field,
                                self.problem.decode(prediction_indices, length_batches[i][key_random].numpy()))
                        elif field == 'confidence':
                            # Confidence of the predicted label: the softmax
                            # probability at the argmax index.
                            prediction_scores = logits_softmax.cpu().data.numpy()
                            for prediction_score, prediction_idx in zip(prediction_scores, prediction_indices):
                                streaming_recorder.record(field, prediction_score[prediction_idx])
                        elif field.startswith('confidence') and field.find('@') != -1:
                            # 'confidence@<label_name>': softmax probability of one
                            # specific label, independent of the predicted class.
                            label_specified = field.split('@')[1]
                            label_specified_idx = self.problem.output_dict.id(label_specified)
                            confidence_specified = torch.index_select(logits_softmax.cpu(), 1,
                                torch.tensor([label_specified_idx], dtype=torch.long)).squeeze(1)
                            streaming_recorder.record(field, confidence_specified.data.numpy())
elif ProblemTypes[self.problem.problem_type] == ProblemTypes.regression:
logits = list(logits.values())[0]
                    # logits_softmax is not used for the regression task; only the
                    # raw scores are reported.
                    logits_softmax = list(logits_softmax.values())[0]
                    logits_flat = logits.squeeze(1)
                    prediction_scores = logits_flat.detach().cpu().numpy()
                    streaming_recorder.record_one_row([prediction_scores])
elif ProblemTypes[self.problem.problem_type] == ProblemTypes.mrc:
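                    # Squeeze singleton dimensions from each output tensor before decoding.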
for key, value in logits.items():
logits[key] = value.squeeze()
for key, value in logits_softmax.items():
logits_softmax[key] = value.squeeze()
                    # Locate the passage input by key name; its lengths and token
                    # ids are needed to map the predicted answer back to text.
                    passage_identifier = None
                    for type_key in data_batches[i].keys():
                        if 'p' in type_key.lower():
                            passage_identifier = type_key
                            break
                    if not passage_identifier:
                        raise Exception('The MRC task needs passage information.')
                    prediction = self.problem.decode(logits_softmax, lengths=length_batches[i][passage_identifier],
                                                     batch_data=data_batches[i][passage_identifier])
                    streaming_recorder.record_one_row([prediction])
logits_len = len(list(logits.values())[0]) \
if ProblemTypes[self.problem.problem_type] == ProblemTypes.mrc else len(logits)
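                # Write one output row per sample: the original input line followed
                # by the requested prediction fields, tab-separated.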
for sample_idx in range(logits_len):
sample = fin.readline().rstrip()
                    fout.write("%s\t%s\n" % (sample,
                        "\t".join([str(streaming_recorder.get(field)[sample_idx]) for field in predict_fields])))
                streaming_recorder.clear_records()
del logits, logits_softmax
fin.close()
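# Usage sketch (hypothetical paths and field names; `lm` is assumed to be an
# already trained LearningMachine instance and `conf` its loaded configuration):
#
#     lm.predict(predict_data_path='data/test.tsv',
#                output_path='outputs/predictions.tsv',
#                file_columns=conf.file_columns,
#                predict_fields=['prediction', 'confidence'])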