def main()

in python/dglke/infer_emb_sim.py [0:0]


def main():
    """Run embedding-similarity inference from the command line and save top-K results.

    Parses options with ``ArgParser``, loads the left/right entity inputs
    according to ``args.format``, and scores them with ``EmbSimInfer``. When
    ``args.raw_data`` is set, entity names are mapped to ids on input and back
    to names on output. The results are written to ``args.output`` as a
    tab-separated file with a ``left\\tright\\tscore`` header.

    Supported ``args.format`` values:
        * ``'l_r'``: left entities from ``data_files[0]``, right entities from
          ``data_files[1]``.
        * ``'l_*'``: left entities from ``data_files[0]``; the right side comes
          from ``load_entity_data()`` with no file.
        * ``'*_r'``: right entities from ``data_files[0]``; the left side comes
          from ``load_entity_data()`` with no file.
        * ``'*'``: both sides come from ``load_entity_data()`` with no file.

    A file-less ``load_entity_data()`` call presumably means "all entities";
    confirm against its definition.

    Raises:
        ValueError: if ``emb_file`` is missing, or if ``format`` or
            ``exec_mode`` is not a supported value.
    """
    args = ArgParser().parse_args()
    # Use explicit raises rather than assert: asserts are stripped under -O.
    if args.emb_file is None:
        raise ValueError('emb_file should be provided for entity embeddings')

    # exec_mode -> (pairwise, bcast). Validate before loading any data so a bad
    # option fails fast instead of after potentially expensive file reads.
    exec_modes = {
        'pairwise': (True, False),
        'batch_left': (False, True),
        'all': (False, False),
    }
    if args.exec_mode not in exec_modes:
        raise ValueError('Unknown execution mode: {}'.format(args.exec_mode))
    pairwise, bcast = exec_modes[args.exec_mode]

    data_files = args.data_files
    # Only populated (and only used) when args.raw_data is set.
    id2e_map = None
    if args.format == 'l_r':
        if args.raw_data:
            head, id2e_map, e2id_map = load_raw_emb_data(file=data_files[0],
                                                         map_f=args.mfile)
            # Reuse the mapping built for the left side so both sides share ids.
            tail, _, _ = load_raw_emb_data(file=data_files[1],
                                           e2id_map=e2id_map)
        else:
            head = load_entity_data(data_files[0])
            tail = load_entity_data(data_files[1])
    elif args.format == 'l_*':
        if args.raw_data:
            head, id2e_map, e2id_map = load_raw_emb_data(file=data_files[0],
                                                         map_f=args.mfile)
        else:
            head = load_entity_data(data_files[0])
        tail = load_entity_data()
    elif args.format == '*_r':
        if args.raw_data:
            tail, id2e_map, e2id_map = load_raw_emb_data(file=data_files[0],
                                                         map_f=args.mfile)
        else:
            tail = load_entity_data(data_files[0])
        head = load_entity_data()
    elif args.format == '*':
        if args.raw_data:
            # No input file: only the id -> name mapping is needed for output.
            id2e_map = load_raw_emb_mapping(map_f=args.mfile)
        head = load_entity_data()
        tail = load_entity_data()
    else:
        # Previously fell through silently and crashed later with NameError.
        raise ValueError('Unsupported format: {}'.format(args.format))

    model = EmbSimInfer(args.gpu, args.emb_file, args.sim_func)
    model.load_emb()
    result = model.topK(head, tail, bcast=bcast, pair_ws=pairwise, k=args.topK)

    with open(args.output, 'w') as f:
        f.write('left\tright\tscore\n')
        # Each result item is a (left ids, right ids, scores) triple of objects
        # exposing .tolist(); presumably tensors. TODO confirm against EmbSimInfer.topK.
        for head_ids, tail_ids, scores in result:
            for h, t, s in zip(head_ids.tolist(), tail_ids.tolist(), scores.tolist()):
                if args.raw_data:
                    # Translate internal ids back to the raw entity names.
                    h = id2e_map[h]
                    t = id2e_map[t]
                f.write('{}\t{}\t{}\n'.format(h, t, s))
    print('Inference Done')
    print('The result is saved in {}'.format(args.output))