in src/models.py [0:0]
def get_ranking(self,
queries: torch.Tensor,
filters: Dict[Tuple[int, int], List[int]],
batch_size: int = 1000, chunk_size: int = -1,
candidates: str = 'rhs'):
    """
    Returns filtered ranking for each query.
    :param queries: a torch.LongTensor of triples (lhs, rel, rhs)
    :param filters: filters[(lhs, rel)] gives the rhs to filter from ranking
    :param batch_size: maximum number of queries processed at once
    :param chunk_size: maximum number of answering candidates processed at once;
        -1 means score against all candidates in a single chunk
    :param candidates: which slot to rank over ('rhs' by default); forwarded to
        get_queries/get_candidates as `target`
    :return: (ranks, predicted) where ranks[i] is the filtered rank of the true
        answer for queries[i] (1 = best) and predicted[i] is the global index of
        the top-scoring candidate for queries[i]
    """
    query_type = candidates
    if chunk_size < 0:  # not chunking, score against all candidates at once
        chunk_size = self.sizes[2]  # entity ranking
    n_queries = len(queries)
    ranks = torch.ones(n_queries)  # rank starts at 1 (true answer itself)
    predicted = torch.zeros(n_queries)
    # Best score seen so far per query, across chunks; lets us keep a globally
    # correct argmax instead of the last chunk's local argmax.
    best_scores = torch.full((n_queries,), -float('inf'))
    with torch.no_grad():
        c_begin = 0
        while c_begin < self.sizes[2]:
            b_begin = 0
            cands = self.get_candidates(c_begin, chunk_size, target=query_type)
            while b_begin < len(queries):
                these_queries = queries[b_begin:b_begin + batch_size]
                q = self.get_queries(these_queries, target=query_type)
                scores = q @ cands  # torch.mv MIPS
                targets = self.score(these_queries)
                if filters is not None:
                    scores = filtering(scores, these_queries, filters,
                                       n_rel=self.sizes[1], n_ent=self.sizes[2],
                                       c_begin=c_begin, chunk_size=chunk_size,
                                       query_type=query_type)
                ranks[b_begin:b_begin + batch_size] += torch.sum(
                    (scores >= targets).float(), dim=1
                ).cpu()
                # BUGFIX: the previous code stored the *local* argmax of the
                # current chunk and overwrote it on every chunk iteration.
                # Track the running best score per query and offset the local
                # argmax by c_begin to get a global candidate index.
                chunk_best, chunk_arg = torch.max(scores, dim=1)
                chunk_best = chunk_best.cpu()
                chunk_arg = chunk_arg.cpu()
                sl = slice(b_begin, b_begin + batch_size)
                improved = chunk_best > best_scores[sl]
                best_scores[sl] = torch.where(improved, chunk_best, best_scores[sl])
                predicted[sl] = torch.where(
                    improved, (chunk_arg + c_begin).float(), predicted[sl]
                )
                b_begin += batch_size
            c_begin += chunk_size
    return ranks, predicted