def main()

in python/dglke/infer_score.py [0:0]


# Per --format: the triplet slots whose IDs are read from --data_files (in the
# order the files must be given) and the error shown when the number of files
# is wrong. A slot that is not listed is passed to the loaders as None.
_FORMAT_SPECS = {
    'h_r_t': (('head', 'rel', 'tail'),
              'When using h_r_t, head.list, rel.list and tail.list should be provided.'),
    'h_r_*': (('head', 'rel'),
              'When using h_r_*, head.list and rel.list should be provided.'),
    'h_*_t': (('head', 'tail'),
              'When using h_*_t, head.list and tail.list should be provided.'),
    '*_r_t': (('rel', 'tail'),
              'When using *_r_t, rel.list and tail.list should be provided.'),
    'h_*_*': (('head',),
              'When using h_*_*, only head.list should be provided.'),
    '*_r_*': (('rel',),
              'When using *_r_*, only rel.list should be provided.'),
    '*_*_t': (('tail',),
              'When using *_*_t, only tail.list should be provided.'),
}


def _load_input_triplets(args):
    """Validate the input-file arguments and load the query triplets.

    Args:
        args: parsed command-line arguments; reads ``format``, ``data_files``,
            ``raw_data``, ``entity_mfile`` and ``rel_mfile``.

    Returns:
        ``(head, rel, tail, id2e_map, id2r_map)``. With ``args.raw_data`` this
        is whatever ``load_raw_triplet_data`` returns; otherwise the two maps
        are ``None``.

    Raises:
        ValueError: if the format is unsupported, a mapping file is missing in
            raw mode, or the number of data files does not match the format.
    """
    spec = _FORMAT_SPECS.get(args.format)
    if spec is None:
        raise ValueError("Unsupported format {}".format(args.format))
    slots, count_msg = spec

    emap_file = args.entity_mfile
    rmap_file = args.rel_mfile
    if args.raw_data:
        # Raw IDs can only be resolved (and the output translated back)
        # with both mapping files.
        if emap_file is None:
            raise ValueError('When using RAW ID through --raw_data, '
                             'entity_mfile should be provided.')
        if rmap_file is None:
            raise ValueError('When using RAW ID through --raw_data, '
                             'rel_mfile should be provided.')

    # Checked in both modes: a short list should fail with a clear message
    # rather than an IndexError, and surplus files should not be ignored.
    data_files = args.data_files or []
    if len(data_files) != len(slots):
        raise ValueError(count_msg)

    files_by_slot = dict(zip(slots, data_files))
    paths = dict(head_f=files_by_slot.get('head'),
                 rel_f=files_by_slot.get('rel'),
                 tail_f=files_by_slot.get('tail'))

    if args.raw_data:
        return load_raw_triplet_data(emap_f=emap_file, rmap_f=rmap_file, **paths)
    head, rel, tail = load_triplet_data(**paths)
    return head, rel, tail, None, None


def _write_results(output_path, result, id2e_map, id2r_map, raw_data):
    """Write scored triplets as a tab-separated file with a header row.

    Each element of ``result`` is unpacked as ``(heads, rels, tails, scores)``,
    each supporting ``.tolist()``. With ``raw_data`` the head/tail IDs are
    looked up in ``id2e_map`` and relation IDs in ``id2r_map`` before writing.
    """
    with open(output_path, 'w') as f:
        f.write('head\trel\ttail\tscore\n')
        for hl, rl, tl, sl in result:
            for h, r, t, s in zip(hl.tolist(), rl.tolist(), tl.tolist(), sl.tolist()):
                if raw_data:
                    h, r, t = id2e_map[h], id2r_map[r], id2e_map[t]
                f.write('{}\t{}\t{}\t{}\n'.format(h, r, t, s))


def main():
    """Command-line entry point: score query triplets and save the top-K results.

    Loads the model config from ``<model_path>/config.json`` and parses the
    input triplets according to ``--format``. It then runs ``ScoreInfer.topK``
    and writes the results to ``--output``.
    """
    args = ArgParser().parse_args()

    config = load_model_config(os.path.join(args.model_path, 'config.json'))
    # parse input data first
    head, rel, tail, id2e_map, id2r_map = _load_input_triplets(args)

    model = ScoreInfer(args.gpu, config, args.model_path, args.score_func)
    model.load_model()
    result = model.topK(head, rel, tail, args.exec_mode, args.topK)

    _write_results(args.output, result, id2e_map, id2r_map, args.raw_data)
    print('Inference Done')
    print('The result is saved in {}'.format(args.output))