in hype_kg/codes/model.py [0:0]
def test_step(model, test_triples, test_ans, test_ans_hard, args):
    """Evaluate ``model`` on ``test_triples`` and return averaged ranking metrics.

    The query structure is read once from the first query
    (``test_triples[0][-1]``) and is used to pick the dataset class and the
    relation-path length handed to the model.

    Args:
        model: Object with ``eval()``, a ``geo`` attribute, and callable as
            ``model((positive, negative), rel_len, qtype, mode=mode)``,
            returning a 6-tuple whose items 0, 1 and 3 are score tensors.
        test_triples: Non-empty sequence of queries; the last element of each
            query is its type string (e.g. ``'inter-chain'`` or ``'2-...'``).
        test_ans: Mapping from query to the set of all its answers (must
            support set difference with a ``set`` of entity ids).
        test_ans_hard: Mapping from query to the answers that are ranked.
        args: Namespace providing ``nentity``, ``nrelation``,
            ``test_batch_size``, ``cpu_num``, ``cuda`` and ``test_log_steps``.

    Returns:
        dict with keys ``'MRRm_new'``, ``'MRm_new'``, ``'HITS@1m_new'``,
        ``'HITS@3m_new'`` and ``'HITS@10m_new'``. Each value is the per-query
        mean of the metric, averaged over all evaluated queries.

    Raises:
        AssertionError: if a batch does not contain exactly one query.
    """
    qtype = test_triples[0][-1]
    if qtype in ('chain-inter', 'inter-chain', 'union-chain'):
        rel_len = 2
    else:
        # Other query types carry the path length as their leading field.
        rel_len = int(qtype.split('-')[0])

    model.eval()

    # Only the dataset class depends on the query structure; every other
    # DataLoader argument is shared.
    if qtype in ('inter-chain', 'union-chain'):
        dataset_cls = TestInterChainDataset
    elif qtype == 'chain-inter':
        dataset_cls = TestChainInterDataset
    elif 'inter' in qtype or 'union' in qtype:
        dataset_cls = TestInterDataset
    else:
        dataset_cls = TestDataset

    test_dataloader_tail = DataLoader(
        dataset_cls(
            test_triples,
            test_ans,
            test_ans_hard,
            args.nentity,
            args.nrelation,
            'tail-batch',
        ),
        batch_size=args.test_batch_size,
        num_workers=max(1, args.cpu_num),
        collate_fn=TestDataset.collate_fn,
    )
    test_dataset_list = [test_dataloader_tail]

    # Every tensor used for ranking lives on this single device. Previously
    # the scores were moved with an unconditional ``.cuda()`` while the mask
    # followed ``args.cuda``, which broke CPU evaluation.
    device = torch.device('cuda') if args.cuda else torch.device('cpu')

    def _rank_hard_answers(raw_score, keep_mask, hard_idx):
        """Return the 1-based filtered rank of each hard answer."""
        shifted = raw_score.to(device)
        # Shift so the minimum score is 1: masked-out entries (value 0 after
        # the multiplication below) then always sort behind real candidates.
        shifted = shifted - (torch.min(shifted) - 1)
        order = torch.argsort(keep_mask * shifted, dim=1, descending=True)
        # Row i of ``order`` is compared against hard answer i; the matching
        # column is that answer's 0-based position in the sorted list.
        hits = (order == hard_idx.unsqueeze(1)).nonzero()
        return hits[:, 1] + 1

    # Loop-invariant: the full entity id universe.
    all_entities = set(range(args.nentity))

    step = 0
    total_steps = sum(len(dataset) for dataset in test_dataset_list)
    logs = []

    with torch.no_grad():
        for test_dataset in test_dataset_list:
            for positive_sample, negative_sample, mode, query in test_dataset:
                if args.cuda:
                    positive_sample = positive_sample.cuda()
                    negative_sample = negative_sample.cuda()

                batch_size = positive_sample.size(0)
                assert batch_size == 1, batch_size

                score, _, _, score_cen_plus, _, _ = model(
                    (positive_sample, negative_sample), rel_len, qtype, mode=mode
                )
                # Box models are reported on the fourth model output; all
                # other geometries on the first.
                ranked_score = score_cen_plus if model.geo == 'box' else score

                ans = test_ans[query]
                hard_ans_list = list(test_ans_hard[query])
                num_hard = len(hard_ans_list)

                # One mask row per hard answer: keep that answer itself plus
                # every non-answer entity, and zero out all other answers so
                # they cannot outrank it (filtered setting).
                false_idx = np.array(list(all_entities - ans), dtype=np.int64)
                vals = np.zeros((num_hard, args.nentity))
                vals[np.arange(num_hard), np.array(hard_ans_list)] = 1
                vals[:, false_idx] = 1

                keep_mask = torch.Tensor(vals).to(device)
                hard_idx = torch.LongTensor(hard_ans_list).to(device)
                ranking = _rank_hard_answers(ranked_score, keep_mask, hard_idx)

                rank_f = ranking.to(torch.float)
                logs.append({
                    'MRRm_new': torch.mean(1. / rank_f).item(),
                    'MRm_new': torch.mean(rank_f).item(),
                    'HITS@1m_new': torch.mean((ranking <= 1).to(torch.float)).item(),
                    'HITS@3m_new': torch.mean((ranking <= 3).to(torch.float)).item(),
                    'HITS@10m_new': torch.mean((ranking <= 10).to(torch.float)).item(),
                })

                if step % args.test_log_steps == 0:
                    logging.info('Evaluating the model... (%d/%d)', step, total_steps)
                step += 1

    # Every reported metric is a per-query mean, so the aggregate is the
    # plain average over queries.
    metrics = {}
    for metric in logs[0]:
        metrics[metric] = sum(log[metric] for log in logs) / len(logs)
    return metrics