in src/models.py [0:0]
# relies on module-level imports: import torch; from collections import defaultdict
def get_metric_ogb(self,
                   queries: torch.Tensor,
                   batch_size: int = 1000,
                   query_type: str = 'rhs',
                   evaluator=None):
    """Compute ranking metrics with an OGB evaluator.

    No filtering is needed here, since the negatives provided by OGB are
    already filtered.

    :param queries: a torch.LongTensor of triples (lhs, rel, rhs), optionally
        with pre-sampled negative entity indices appended as extra columns
    :param batch_size: maximum number of queries processed at once
    :param query_type: which side to rank, e.g. 'rhs' for tail prediction
    :param evaluator: an ogb.linkproppred.Evaluator instance
    :return: dict mapping each evaluator metric to its mean over all queries
    """
test_logs = defaultdict(list)
with torch.no_grad():
b_begin = 0
while b_begin < len(queries):
these_queries = queries[b_begin:b_begin + batch_size]
            # TODO: the number of pre-sampled negatives is hard-coded per dataset
            if these_queries.shape[1] > 5:  # extra columns beyond h, r, t, h_type, t_type carry negative indices
                tot_neg = 1000 if evaluator.name in ['ogbl-biokg', 'ogbl-wikikg2'] else 0
                neg_indices = these_queries[:, 3:3 + tot_neg]
                chunk_begin, chunk_size = None, None
            else:
                # no pre-sampled negatives: rank against all entities
                neg_indices = None
                chunk_begin, chunk_size = 0, self.sizes[2]
q = self.get_queries(these_queries, target=query_type)
cands = self.get_candidates(chunk_begin, chunk_size,
target=query_type,
indices=neg_indices)
            if cands.dim() >= 3:  # per-example candidate embeddings of shape (B, N, d)
                scores = torch.bmm(cands, q.unsqueeze(-1)).squeeze(-1)  # batched dot products -> (B, N)
            else:  # one shared candidate matrix for the whole batch
                scores = q @ cands  # MIPS over all entities -> positive + negative scores
targets = self.score(these_queries) # positive scores
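            # OGB link-prediction evaluators (ogbl-biokg / ogbl-wikikg2) expect
            # y_pred_pos of shape (B,) and y_pred_neg of shape (B, num_neg), and
            # return per-example metric tensors (e.g. 'mrr_list', 'hits@10_list').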
batch_results = evaluator.eval({'y_pred_pos': targets.squeeze(-1),
'y_pred_neg': scores})
            del targets, scores, q, cands  # free per-batch tensors before the next chunk
for metric in batch_results:
test_logs[metric].append(batch_results[metric])
b_begin += batch_size
    # average each per-example metric over the full query set
    metrics = {}
    for metric in test_logs:
        metrics[metric] = torch.cat(test_logs[metric]).mean().item()
return metrics
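
# Usage sketch (illustrative, not part of the original file): drives the method
# above with the real OGB evaluator. `model` and `test_queries` are assumed to
# be a trained instance of the class above and a LongTensor laid out as
# (lhs, rel, rhs, <pre-filtered negative indices...>).
from ogb.linkproppred import Evaluator

def evaluate_tail_prediction(model, test_queries):
    evaluator = Evaluator(name='ogbl-biokg')  # emits per-example 'mrr_list', 'hits@k_list'
    metrics = model.get_metric_ogb(test_queries,
                                   batch_size=1000,
                                   query_type='rhs',  # rank candidate tails
                                   evaluator=evaluator)
    # each value is the metric's mean over all test queries,
    # e.g. metrics['mrr_list'] is the test MRR
    return metrics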