in src/deep_baselines/run.py [0:0]
def predict(args, model, dataset, prefix, log_fp=None):
    '''
    Run inference on a test dataset and write per-sample predictions and metrics to disk.
    :param args: parsed arguments (device, per_gpu_eval_batch_size, n_gpu, output_dir, output_mode, label_filepath, local_rank, ...)
    :param model: trained model to evaluate
    :param dataset: test dataset
    :param prefix: sub-directory name under args.output_dir for this run's outputs
    :param log_fp: optional open file handle for extra logging
    :return: (pred_label_names, true_label_names, result) where result is the metrics dict
    '''
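    # Each batch is expected to collate into (x, labels, lengths) tensors; these map
    # directly onto the model inputs built inside the test loop below.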
    output_dir = os.path.join(args.output_dir, prefix)
    print("Testing info save dir: ", output_dir)
    if not os.path.exists(output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(output_dir)
    args.test_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # SequentialSampler keeps the original dataset order, so predictions line up with the inputs
    test_dataset_total_num = len(dataset)
    test_sampler = SequentialSampler(dataset)
    test_dataloader = DataLoader(dataset, sampler=test_sampler, batch_size=args.test_batch_size)
    test_batch_total_num = len(test_dataloader)
    print("Test dataset len: %d, batch len: %d" % (test_dataset_total_num, test_batch_total_num))
    # Multi GPU
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    # Eval!
    logger.info("***** Running test {} *****".format(prefix))
    logger.info("Num examples = %d", test_dataset_total_num)
    logger.info("Batch size = %d", args.test_batch_size)
    if log_fp:
        log_fp.write("***** Running testing {} *****\n".format(prefix))
        log_fp.write("Test Dataset Num examples = %d\n" % test_dataset_total_num)
        log_fp.write("Test Dataset Instantaneous batch size per GPU = %d\n" % args.per_gpu_eval_batch_size)
        log_fp.write("Test Dataset batch number = %d\n" % test_batch_total_num)
        log_fp.write("#" * 50 + "\n")
    test_loss = 0.0
    nb_test_steps = 0
    # predicted probabilities/scores accumulated over the whole test set
    pred_scores = None
    # ground-truth labels accumulated over the whole test set
    out_label_ids = None
    model.eval()
    for batch in tqdm(test_dataloader, total=test_batch_total_num, desc="Testing"):
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            inputs = {
                "x": batch[0],
                "labels": batch[1],
                "lengths": batch[2]
            }
            outputs = model(**inputs)
            tmp_test_loss, logits, output = outputs[:3]
            test_loss += tmp_test_loss.mean().item()
        nb_test_steps += 1
        if pred_scores is None:
            pred_scores = output.detach().cpu().numpy()
            out_label_ids = inputs["labels"].detach().cpu().numpy()
        else:
            pred_scores = np.append(pred_scores, output.detach().cpu().numpy(), axis=0)
            out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
    test_loss = test_loss / nb_test_steps
    if args.output_mode in ["multi_class", "multi-class"]:
        label_list = load_labels(filepath=args.label_filepath, header=True)
        pred_label_names = label_id_2_label_name(args.output_mode, label_list=label_list, prob=pred_scores, threshold=0.5)
        true_label_names = [label_list[idx] for idx in out_label_ids]
    elif args.output_mode == "regression":
        preds = np.squeeze(pred_scores)
        pred_label_names = list(preds)
        true_label_names = list(out_label_ids)
    elif args.output_mode in ["multi_label", "multi-label"]:
        label_list = load_labels(filepath=args.label_filepath, header=True)
        pred_label_names = label_id_2_label_name(args.output_mode, label_list=label_list, prob=pred_scores, threshold=0.5)
        true_label_names = label_id_2_label_name(args.output_mode, label_list=label_list, prob=out_label_ids, threshold=0.5)
    elif args.output_mode in ["binary_class", "binary-class"]:
        label_list = load_labels(filepath=args.label_filepath, header=True)
        pred_label_names = label_id_2_label_name(args.output_mode, label_list=label_list, prob=pred_scores, threshold=0.5)
        true_label_names = label_id_2_label_name(args.output_mode, label_list=label_list, prob=out_label_ids, threshold=0.5)
if args.output_mode in ["multi_class", "multi-class"]:
result = metrics_multi_class(out_label_ids, pred_scores)
elif args.output_mode in ["multi_label", "multi-label"]:
result = metrics_multi_label(out_label_ids, pred_scores, threshold=0.5)
elif args.output_mode == "regression":
pass # to do
elif args.output_mode in ["binary_class", "binary-class"]:
result = metrics_binary(out_label_ids, pred_scores, threshold=0.5,
savepath=os.path.join(output_dir, "test_confusion_matrix.png"))
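    # Persist per-sample predictions ("index,predicted,true" per line) and the metric
    # dictionary ("key = value" per line) under output_dir.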
with open(os.path.join(output_dir, "test_results.txt"), "w") as wfp:
for idx in range(len(pred_label_names)):
wfp.write("%d,%s,%s\n" %(idx, str(pred_label_names[idx]), str(true_label_names[idx])))
with open(os.path.join(output_dir, "test_metrics.txt"), "w") as wfp:
logger.info("***** Eval Test results {} *****".format(prefix))
for key in sorted(result.keys()):
logger.info("%s = %s", key, str(result[key]))
wfp.write("%s = %s\n" % (key, str(result[key])))
logger.info("Test metrics: ")
logger.info(json.dumps(result, ensure_ascii=False))
logger.info("")
return pred_label_names, true_label_names, result
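

# A minimal usage sketch, not part of the original run.py: it shows how predict() might be
# invoked after training, assuming the caller already has `args`, a trained `model`, and a
# `test_dataset`; the helper name, log filename, and "test" prefix are illustrative only.
def _example_predict_call(args, model, test_dataset):
    # hypothetical helper for illustration; mirrors the (args, model, dataset, prefix, log_fp) signature above
    with open(os.path.join(args.output_dir, "prediction_log.txt"), "w") as log_fp:
        pred_label_names, true_label_names, result = predict(
            args, model, test_dataset, prefix="test", log_fp=log_fp
        )
    logger.info("Test metrics: %s", json.dumps(result, ensure_ascii=False))
    return pred_label_names, true_label_names, result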