in train_extractive_reader.py [0:0]
def _calc_loss(self, input: ReaderBatch) -> torch.Tensor:
    """Compute the extractive-reader loss for one batch.

    Moves every tensor in the batch onto the configured device, then:

    * in training mode, passes the gold span positions to the reader so the
      model computes the loss internally;
    * in eval mode, runs a gradient-free forward pass and derives the loss
      from the returned start/end/rank logits via ``compute_loss``.

    The result is mean-reduced when running multi-GPU data-parallel and
    scaled down by the gradient-accumulation factor.

    :param input: reader batch (name shadows the ``input`` builtin; kept
        unchanged for backward compatibility with existing callers)
    :return: scalar loss tensor
    """
    cfg = self.cfg
    # Rebuild the batch with all tensors moved to the target device.
    input = ReaderBatch(**move_to_device(input._asdict(), cfg.device))
    attn_mask = self.tensorizer.get_attn_mask(input.input_ids)
    questions_num, passages_per_question, _ = input.input_ids.size()

    if self.reader.training:
        # Supplying the gold spans makes the reader return the loss directly.
        loss = self.reader(
            input.input_ids,
            attn_mask,
            input.start_positions,
            input.end_positions,
            input.answers_mask,
        )
    else:
        # NOTE(review): original carried a bare "TODO: remove?" here —
        # eval-time loss may be unused by callers; confirm before deleting.
        with torch.no_grad():
            start_logits, end_logits, rank_logits = self.reader(input.input_ids, attn_mask)
        loss = compute_loss(
            input.start_positions,
            input.end_positions,
            input.answers_mask,
            start_logits,
            end_logits,
            rank_logits,
            questions_num,
            passages_per_question,
        )

    if cfg.n_gpu > 1:
        # DataParallel gathers one loss per replica; average to a scalar.
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        # Keep effective gradient magnitude constant across accumulation steps.
        loss = loss / cfg.train.gradient_accumulation_steps
    return loss