in python/dglke/infer_emb_sim.py [0:0]
def main():
    """Run embedding-similarity inference from the CLI and save top-K results.

    Parses command-line arguments with ``ArgParser``, loads the left (head)
    and right (tail) entity sets according to ``args.format``, runs
    ``EmbSimInfer.topK`` in the mode selected by ``args.exec_mode`` and writes
    a tab-separated file with a ``left``/``right``/``score`` header line to
    ``args.output``.

    ``args.format`` selects where each side comes from:
      * ``'l_r'`` -- left from ``data_files[0]``, right from ``data_files[1]``.
      * ``'l_*'`` -- left from ``data_files[0]``; right from
        ``load_entity_data()`` called without a file (presumably all
        entities -- confirm against ``load_entity_data``).
      * ``'*_r'`` -- right from ``data_files[0]``; left from
        ``load_entity_data()``.
      * ``'*'``   -- both sides from ``load_entity_data()``.

    When ``args.raw_data`` is set, inputs are read with the raw-name loaders
    and the IDs in the output are translated back through ``id2e_map``.

    ``args.exec_mode`` maps to the ``topK`` flags:
      ``'pairwise'`` -> pair_ws=True, bcast=False;
      ``'batch_left'`` -> pair_ws=False, bcast=True;
      ``'all'`` -> pair_ws=False, bcast=False.

    Raises:
        ValueError: if ``emb_file`` is not provided, or ``format`` /
            ``exec_mode`` is not one of the supported values.
    """
    args = ArgParser().parse_args()
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if args.emb_file is None:
        raise ValueError('emb_file should be provided for entity embeddings')

    # exec_mode -> (pair_ws, bcast) flags forwarded to EmbSimInfer.topK.
    # Validated before any data loading so a bad value fails fast.
    exec_flags = {
        'pairwise': (True, False),
        'batch_left': (False, True),
        'all': (False, False),
    }
    if args.exec_mode not in exec_flags:
        raise ValueError('Unknown execution mode: {}'.format(args.exec_mode))
    pairwise, bcast = exec_flags[args.exec_mode]

    data_files = args.data_files
    id2e_map = None
    if args.format == 'l_r':
        if args.raw_data:
            head, id2e_map, e2id_map = load_raw_emb_data(file=data_files[0],
                                                         map_f=args.mfile)
            # Pass the left side's mapping into the second load, presumably so
            # both sides are encoded with the same entity IDs.
            tail, _, _ = load_raw_emb_data(file=data_files[1],
                                           e2id_map=e2id_map)
        else:
            head = load_entity_data(data_files[0])
            tail = load_entity_data(data_files[1])
    elif args.format == 'l_*':
        if args.raw_data:
            head, id2e_map, e2id_map = load_raw_emb_data(file=data_files[0],
                                                         map_f=args.mfile)
        else:
            head = load_entity_data(data_files[0])
        tail = load_entity_data()
    elif args.format == '*_r':
        if args.raw_data:
            tail, id2e_map, e2id_map = load_raw_emb_data(file=data_files[0],
                                                         map_f=args.mfile)
        else:
            tail = load_entity_data(data_files[0])
        head = load_entity_data()
    elif args.format == '*':
        if args.raw_data:
            # No input file: only the ID -> raw-name mapping is needed for output.
            id2e_map = load_raw_emb_mapping(map_f=args.mfile)
        head = load_entity_data()
        tail = load_entity_data()
    else:
        # Previously fell through and failed later with a NameError on head/tail.
        raise ValueError('Unknown format: {}'.format(args.format))

    model = EmbSimInfer(args.gpu, args.emb_file, args.sim_func)
    model.load_emb()
    result = model.topK(head, tail, bcast=bcast, pair_ws=pairwise, k=args.topK)

    map_ids = args.raw_data
    with open(args.output, 'w') as f:
        f.write('left\tright\tscore\n')
        # Each result item is a (heads, tails, scores) triple of array-likes
        # exposing ``tolist()``.
        for hl, tl, sl in result:
            for h, t, s in zip(hl.tolist(), tl.tolist(), sl.tolist()):
                if map_ids:
                    h = id2e_map[h]
                    t = id2e_map[t]
                f.write('{}\t{}\t{}\n'.format(h, t, s))
    print('Inference Done')
    print('The result is saved in {}'.format(args.output))