in run_ranking.py [0:0]
def test_epoch_end(self, outputs):
    """Attach ranking scores to retrieved candidates and write them to disk.

    Concatenates the ``'logits'`` tensors collected from every test step,
    turns them into probabilities with a sigmoid, and assigns them in order
    to the candidates read from
    ``<data_dir>/{question_type}.dev.es_retrieved.processed.jsonl``.
    For each line, the candidates are its first 100 passages followed by its
    first 100 tables. Each candidate dict gains a ``'rank_score'`` key, and
    the truncated lines are written to
    ``<data_dir>/scores/{question_type}.dev.es_retrieved.processed.scores.jsonl``.

    NOTE(review): this assumes the test dataloader emits exactly one logit
    per candidate, in this same order and with the same truncation to 100.
    Confirm this against the dataset code. A count mismatch is logged as a
    warning below.

    Args:
        outputs: per-step results, each a dict with a ``'logits'`` tensor.

    Returns:
        ``{'avg_test_loss': -1}``, a placeholder. No test loss is computed.

    Raises:
        IndexError: if there are fewer scores than candidates.
    """
    # detach/float: .tolist() must not see a grad-requiring tensor, and
    # fp16 sigmoid may be unsupported on CPU.
    logits = torch.cat([x['logits'] for x in outputs], dim=0).detach().cpu().float()
    # torch.sigmoid replaces the deprecated F.sigmoid. reshape(-1) accepts
    # both (N, 1) and (N,) logits, whereas squeeze(1) fails on 1-D input.
    scores = torch.sigmoid(logits).reshape(-1).tolist()  # (num_test, )

    max_candidates = 100  # must match the truncation used to build the test set
    data_path = os.path.join(
        self.hparams.data_dir,
        f"{self.hparams.question_type}.dev.es_retrieved.processed.jsonl",
    )

    lines = []
    count = 0
    with jsonlines.open(data_path) as reader:
        for line in reader.iter():
            # Passages are scored before tables within each line; this order
            # must mirror how the test set was flattened.
            for key in ('passages', 'tables'):
                line[key] = line[key][:max_candidates]
                for cand in line[key]:
                    cand['rank_score'] = scores[count]
                    count += 1
                    if count % 10000 == 0:
                        logger.info(f"Finished {count} lines")
            lines.append(line)

    if count != len(scores):
        # Extra scores mean every rank_score above may be misaligned.
        logger.warning(
            f"Assigned {count} scores but model produced {len(scores)}; "
            "candidate order/truncation may not match the test dataloader"
        )

    score_dir = os.path.join(self.hparams.data_dir, "scores")
    # exist_ok replaces the racy exists()-then-makedirs() check.
    os.makedirs(score_dir, exist_ok=True)
    score_path = os.path.join(
        score_dir,
        f"{self.hparams.question_type}.dev.es_retrieved.processed.scores.jsonl",
    )
    with jsonlines.open(score_path, 'w') as writer:
        writer.write_all(lines)
    return {'avg_test_loss': -1}