in src/model.py [0:0]
def forward(self,
            question_ids,
            question_mask,
            passage_ids,
            passage_mask,
            gold_score=None):
    """Score every candidate passage against its own question.

    Questions and passages are both embedded with ``self.embed_text``, using
    the question/passage masking flags and ``extract_cls`` from
    ``self.config``. Each question embedding is then dotted with the
    embeddings of its passages. The result is divided by the square root of
    the question embedding's last dimension.

    Args:
        question_ids: question token ids, passed to ``embed_text``. The
            resulting embedding must be 2-D (batch, dim) for the einsum below.
        question_mask: mask paired with ``question_ids``.
        passage_ids: 3-D tensor unpacked as (bsz, n_passages, plen).
        passage_mask: mask viewed to (bsz * n_passages, plen), so it must
            hold the same number of elements as ``passage_ids``.
        gold_score: optional targets. When given, ``self.kldivloss(score,
            gold_score)`` is returned as the loss.

    Returns:
        Tuple ``(question_output, passage_output, score, loss)``:
        ``passage_output`` is the embedding of the flattened passage batch
        exactly as returned by ``embed_text`` (not reshaped back).
        ``score`` has shape (bsz, n_passages). ``loss`` is ``None`` when
        ``gold_score`` is ``None``.
    """
    cfg = self.config

    # Question side: the embedding is consumed as 'bd' by the einsum below.
    q_emb = self.embed_text(
        text_ids=question_ids,
        text_mask=question_mask,
        apply_mask=cfg.apply_question_mask,
        extract_cls=cfg.extract_cls,
    )

    # Passage side: merge the passage axis into the batch axis, so all
    # bsz * n_passages passages go through embed_text in a single call.
    batch_size, passages_per_q, passage_len = passage_ids.size()
    flat_count = batch_size * passages_per_q
    p_emb = self.embed_text(
        text_ids=passage_ids.view(flat_count, passage_len),
        text_mask=passage_mask.view(flat_count, passage_len),
        apply_mask=cfg.apply_passage_mask,
        extract_cls=cfg.extract_cls,
    )

    # Regroup passages under their question, take dot products, and scale
    # down by sqrt of the embedding dimension.
    per_question = p_emb.view(batch_size, passages_per_q, -1)
    scale = np.sqrt(q_emb.size(-1))
    score = torch.einsum('bd,bid->bi', q_emb, per_question) / scale

    loss = None if gold_score is None else self.kldivloss(score, gold_score)
    return q_emb, p_emb, score, loss