in python/dglke/infer_score.py [0:0]
# Input layout for every supported --format value.  Each entry holds the
# position of the head / rel / tail list inside ``data_files`` (None marks a
# wildcard slot for which no file is supplied) followed by the message raised
# when the number of supplied files does not match the format.
_FORMAT_SPECS = {
    'h_r_t': (0, 1, 2, 'When using h_r_t, head.list, rel.list and tail.list '
                       'should be provided.'),
    'h_r_*': (0, 1, None, 'When using h_r_*, head.list and rel.list '
                          'should be provided.'),
    'h_*_t': (0, None, 1, 'When using h_*_t, head.list and tail.list '
                          'should be provided.'),
    '*_r_t': (None, 0, 1, 'When using *_r_t rel.list and tail.list '
                          'should be provided.'),
    'h_*_*': (0, None, None, 'When using h_*_*, only head.list should be provided.'),
    '*_r_*': (None, 0, None, 'When using *_r_*, only rel.list should be provided.'),
    '*_*_t': (None, None, 0, 'When using *_*_t, only tail.list should be provided.'),
}


def _require(condition, message):
    """Raise ``AssertionError(message)`` unless ``condition`` is true.

    A plain ``assert`` statement is removed when Python runs with ``-O``, which
    would silently disable the command-line validation.  Raising explicitly
    keeps the checks active while preserving the exception type and message.
    """
    if not condition:
        raise AssertionError(message)


def _load_input_triplets(args):
    """Validate the input-related arguments and load the head/rel/tail data.

    Returns a ``(head, rel, tail, id2e_map, id2r_map)`` tuple.  The two maps
    come from ``load_raw_triplet_data`` and are ``None`` when ``--raw_data``
    is not set.

    Raises:
        AssertionError: if the format is unsupported, if a mapping file is
            missing in raw-ID mode, or if the number of data files does not
            match the chosen format.
    """
    spec = _FORMAT_SPECS.get(args.format)
    _require(spec is not None, "Unsupported format {}".format(args.format))
    head_idx, rel_idx, tail_idx, count_msg = spec
    slots = (head_idx, rel_idx, tail_idx)
    expected_files = sum(idx is not None for idx in slots)

    emap_file = args.entity_mfile
    rmap_file = args.rel_mfile
    data_files = args.data_files

    if args.raw_data:
        _require(emap_file is not None,
                 'When using RAW ID through --raw_data, '
                 'entity_mfile should be provided.')
        _require(rmap_file is not None,
                 'When using RAW ID through --raw_data, '
                 'rel_mfile should be provided.')
    # Checked in both modes: previously only raw-ID mode verified the count,
    # so plain-ID mode hit a bare IndexError on too few files and silently
    # ignored surplus ones.
    _require(len(data_files) == expected_files, count_msg)

    head_f, rel_f, tail_f = (None if idx is None else data_files[idx]
                             for idx in slots)

    if args.raw_data:
        head, rel, tail, id2e_map, id2r_map = load_raw_triplet_data(
            head_f=head_f,
            rel_f=rel_f,
            tail_f=tail_f,
            emap_f=emap_file,
            rmap_f=rmap_file)
        return head, rel, tail, id2e_map, id2r_map

    head, rel, tail = load_triplet_data(head_f=head_f,
                                        rel_f=rel_f,
                                        tail_f=tail_f)
    return head, rel, tail, None, None


def main():
    """Command-line entry point: score input triplets and write the result.

    Steps performed:
      1. Parse the command-line arguments and load ``config.json`` from
         ``args.model_path``.
      2. Load the head / rel / tail inputs according to ``args.format``.
      3. Build a ``ScoreInfer`` model, load it and call ``topK``.
      4. Write the rows to ``args.output`` as tab-separated text under the
         header ``head  rel  tail  score``.  With ``--raw_data`` the IDs are
         translated back through the maps returned by the raw-data loader.

    Raises:
        AssertionError: on invalid input arguments (see
            ``_load_input_triplets``).
    """
    args = ArgParser().parse_args()
    config = load_model_config(os.path.join(args.model_path, 'config.json'))

    # parse input data first
    head, rel, tail, id2e_map, id2r_map = _load_input_triplets(args)

    model = ScoreInfer(args.gpu, config, args.model_path, args.score_func)
    model.load_model()
    result = model.topK(head, rel, tail, args.exec_mode, args.topK)

    with open(args.output, 'w+') as f:
        f.write('head\trel\ttail\tscore\n')
        for hl, rl, tl, sl in result:
            rows = zip(hl.tolist(), rl.tolist(), tl.tolist(), sl.tolist())
            for h, r, t, s in rows:
                if args.raw_data:
                    # Map internal IDs back to the caller's raw identifiers.
                    h = id2e_map[h]
                    r = id2r_map[r]
                    t = id2e_map[t]
                f.write('{}\t{}\t{}\t{}\n'.format(h, r, t, s))
    print('Inference Done')
    print('The result is saved in {}'.format(args.output))