# test()
#
# from python/dglke/train_pytorch.py

def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
    """Run one evaluation pass of `model` over all graphs in `test_samplers`.

    Parameters
    ----------
    args : Namespace
        Parsed command-line options. Fields used here: ``gpu``,
        ``mix_cpu_gpu``, ``num_proc``, ``strict_rel_part``, ``soft_rel_part``,
        ``dataset`` and ``save_path``.
    model : KEModel
        Knowledge-graph embedding model providing ``forward_test`` /
        ``forward_test_wikikg`` and (optionally) ``load_relation``.
    test_samplers : list
        Evaluation samplers; each is fully consumed and then reset in place.
    rank : int
        Worker rank — selects the GPU in multi-process mode and tags printed
        output / the per-rank result file.
    mode : str
        'Test' or 'Valid'; for wikikg90M this switches between dumping
        predictions and aggregating validation metrics.
    queue : multiprocessing.Queue, optional
        When given, raw per-sample logs are pushed to the queue for the
        parent process to aggregate; otherwise metrics are printed locally.
    """
    # Pick the evaluation device: round-robin over visible GPUs when several
    # mixed CPU/GPU workers run in parallel, otherwise the first GPU; -1
    # means CPU-only evaluation.
    if len(args.gpu) > 0:
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1

    if args.strict_rel_part or args.soft_rel_part:
        # BUG FIX: the original unconditionally built 'cuda:<id>', which is an
        # invalid device string ('cuda:-1') when evaluating on CPU.
        device = th.device('cpu') if gpu_id < 0 else th.device('cuda:' + str(gpu_id))
        model.load_relation(device)

    if args.dataset == "wikikg90M":
        with th.no_grad():
            logs = []
            answers = []
            for sampler in test_samplers:
                for query, ans, candidate in sampler:
                    model.forward_test_wikikg(query, ans, candidate, mode, logs, gpu_id)
                    answers.append(ans)
            print("[{}] finished {} forward".format(rank, mode))

            # Reset every sampler so it can be reused on the next evaluation.
            for i in range(len(test_samplers)):
                test_samplers[i] = test_samplers[i].reset()

            if mode == "Valid":
                # Average each logged metric over all evaluated samples.
                metrics = {}
                if len(logs) > 0:
                    for metric in logs[0].keys():
                        metrics[metric] = sum(log[metric] for log in logs) / len(logs)
                if queue is not None:
                    # Hand raw logs to the parent for cross-worker aggregation.
                    queue.put(logs)
                else:
                    for k, v in metrics.items():
                        print('[{}]{} average {}: {}'.format(rank, mode, k, v))
            else:
                # Test mode: persist top-10 tail predictions per rank for
                # offline scoring.
                input_dict = {}
                input_dict['h,r->t'] = {'t_correct_index': th.cat(answers, 0), 't_pred_top10': th.cat(logs, 0)}
                th.save(input_dict, os.path.join(args.save_path, "test_{}.pkl".format(rank)))
    else:
        with th.no_grad():
            logs = []

            for sampler in test_samplers:
                for pos_g, neg_g in sampler:
                    model.forward_test(pos_g, neg_g, logs, gpu_id)

            # Average each logged metric over all evaluated samples.
            metrics = {}
            if len(logs) > 0:
                for metric in logs[0].keys():
                    metrics[metric] = sum(log[metric] for log in logs) / len(logs)
            if queue is not None:
                queue.put(logs)
            else:
                for k, v in metrics.items():
                    print('[{}]{} average {}: {}'.format(rank, mode, k, v))
        # CONSISTENCY FIX: the original reset only indices 0 and 1, although
        # the loop above consumes every sampler (and the wikikg branch resets
        # all of them) — reset them all, whatever their number.
        for i in range(len(test_samplers)):
            test_samplers[i] = test_samplers[i].reset()