in python/dglke/train_pytorch.py [0:0]
def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
if len(args.gpu) > 0:
gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
else:
gpu_id = -1
if args.strict_rel_part or args.soft_rel_part:
model.load_relation(th.device('cuda:' + str(gpu_id)))
if args.dataset == "wikikg90M":
with th.no_grad():
logs = []
answers = []
for sampler in test_samplers:
for query, ans, candidate in sampler:
model.forward_test_wikikg(query, ans, candidate, mode, logs, gpu_id)
answers.append(ans)
print("[{}] finished {} forward".format(rank, mode))
for i in range(len(test_samplers)):
test_samplers[i] = test_samplers[i].reset()
if mode == "Valid":
metrics = {}
if len(logs) > 0:
for metric in logs[0].keys():
metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
if queue is not None:
queue.put(logs)
else:
for k, v in metrics.items():
print('[{}]{} average {}: {}'.format(rank, mode, k, v))
else:
input_dict = {}
input_dict['h,r->t'] = {'t_correct_index': th.cat(answers, 0), 't_pred_top10': th.cat(logs, 0)}
th.save(input_dict, os.path.join(args.save_path, "test_{}.pkl".format(rank)))
else:
with th.no_grad():
logs = []
for sampler in test_samplers:
for pos_g, neg_g in sampler:
model.forward_test(pos_g, neg_g, logs, gpu_id)
metrics = {}
if len(logs) > 0:
for metric in logs[0].keys():
metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
if queue is not None:
queue.put(logs)
else:
for k, v in metrics.items():
print('[{}]{} average {}: {}'.format(rank, mode, k, v))
test_samplers[0] = test_samplers[0].reset()
test_samplers[1] = test_samplers[1].reset()