in src/datasets.py [0:0]
def eval(self,
         model: KBCModel, split: str,
         n_queries: int = -1,
         n_epochs: int = -1,
         query_type: str = 'both', at: Tuple[int, ...] = (1, 3, 10)):
    """Evaluate ``model`` on one split of this dataset, per query direction.

    Args:
        model: model under evaluation. Its ``get_metric_ogb`` method is used
            when ``'ogb'`` appears in ``self.name``; otherwise ``get_ranking``.
        split: split name forwarded to ``self.get_examples``.
        n_queries: if > 0, evaluate on ``subsample(examples, n_queries)``
            instead of the whole split. The subset is drawn separately for
            each direction.
        n_epochs: not used by this method; kept for interface compatibility.
        query_type: ``'rhs'``, ``'lhs'`` or ``'both'`` (evaluates rhs, then lhs).
        at: cut-offs ``k`` for the Hits@k metrics.

    Returns:
        ``(mean_reciprocal_rank, hits_at, res)``, each keyed by direction:

        - ``mean_reciprocal_rank[m]``: MRR as a Python float, or OGB's
          ``'mrr_list'`` entry for OGB datasets.
        - ``hits_at[m]``: ``torch.FloatTensor`` with one Hits@k value per
          entry of ``at``.
        - ``res[m]``: dict with the ``'query'`` triples that were ranked, plus
          their ``'rank'`` and ``'predicted'`` outputs from
          ``model.get_ranking``. ``res`` itself is ``None`` for OGB datasets.
    """
    print('Evaluate the split {}'.format(split))
    test = self.get_examples(split)
    examples = torch.from_numpy(test).to(self.device)
    query_types = ['rhs', 'lhs'] if query_type == 'both' else [query_type]
    res, mean_reciprocal_rank, hits_at = {}, {}, {}
    for m in query_types:
        print('Evaluating the {}'.format(m))
        if n_queries > 0:  # only evaluate on a subset of the split
            q = subsample(examples, n_queries)
        else:
            # Clone so later transformations of `q` never alias `examples`.
            # NOTE(review): confirm `subsample` also returns a copy if
            # `invert_torch` ever operates in place.
            q = examples.clone()
        candidate_pos = m
        if m == 'lhs' and self.reciprocal:
            # With reciprocal relations, lhs queries are inverted so that the
            # candidates to score end up at rhs.
            q = invert_torch(q, self.n_predicates // 2, include_type=self.include_type)
            candidate_pos = 'rhs'
        if 'ogb' in self.name:
            evaluator = Evaluator(name=self.name)
            metrics = model.get_metric_ogb(q,
                                           batch_size=self.bsz_vt,
                                           query_type=candidate_pos,
                                           evaluator=evaluator)
            mean_reciprocal_rank[m] = metrics['mrr_list']
            hits_at[m] = torch.FloatTensor([metrics['hits@{}_list'.format(k)] for k in at])
            res = None  # per-query ranks are not collected for OGB datasets
        else:
            ranks, predicted = model.get_ranking(q, self.to_skip[m],
                                                 batch_size=self.bsz_vt,
                                                 candidates=candidate_pos)
            mean_reciprocal_rank[m] = torch.mean(1. / ranks).item()
            hits_at[m] = torch.FloatTensor(
                [torch.mean((ranks <= k).float()).item() for k in at])
            # Store the triples that were actually ranked (`q`, after any
            # subsampling / inversion), not the full `examples`. This keeps
            # 'query' aligned with 'rank' and 'predicted' (assuming
            # get_ranking returns one rank per row of `q`).
            res[m] = {'query': q,
                      'rank': ranks,
                      'predicted': predicted}
        del q
    return mean_reciprocal_rank, hits_at, res