in python/dglke/train_mxnet.py [0:0]
def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
assert args.num_proc <= 1, "MXNet KGE does not support multi-process now"
logs = []
if len(args.gpu) > 0:
gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
else:
gpu_id = -1
if args.strict_rel_part:
model.load_relation(mx.gpu(gpu_id))
for sampler in test_samplers:
#print('Number of tests: ' + len(sampler))
count = 0
for pos_g, neg_g in sampler:
model.forward_test(pos_g, neg_g, logs, gpu_id)
metrics = {}
if len(logs) > 0:
for metric in logs[0].keys():
metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
for k, v in metrics.items():
print('{} average {}: {}'.format(mode, k, v))
for i in range(len(test_samplers)):
test_samplers[i] = test_samplers[i].reset()