in pytorch_transformers/modeling_distilbert.py [0:0]
def forward(self, input_ids, attention_mask=None, answer_mask=None,
            start_positions=None, end_positions=None, switches=None,
            head_mask=None, global_step=None):
    distilbert_output = self.bert(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  head_mask=head_mask)
    hidden_states = distilbert_output[0]         # (bs, max_query_len, dim)
    hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)
    logits = self.qa_outputs(hidden_states)      # (bs, max_query_len, 2)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)      # (bs, max_query_len)
    end_logits = end_logits.squeeze(-1)          # (bs, max_query_len)
    # Answer-type ("switch") classification from a max-pooled sequence representation.
    switch_logits = self.qa_classifier(torch.max(hidden_states, 1)[0])

    outputs = (start_logits, end_logits, switch_logits) + distilbert_output[1:]
    if start_positions is not None and end_positions is not None and switches is not None:
        assert global_step is not None
        # If we are on multi-GPU, splitting adds a dimension; squeeze it away.
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        if len(switches.size()) > 1:
            switches = switches.squeeze(-1)
        # Sometimes the start/end positions are outside our model inputs; ignore these terms.
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)
        # Masking: zero out padded answer candidates; span losses additionally
        # require the candidate to be a span answer (switch == 0).
        answer_mask = answer_mask.float().to(start_logits.device)
        span_mask = answer_mask * (switches == 0).float().to(start_logits.device)
        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index, reduction='none')
        start_losses = [loss_fct(start_logits, _start_positions) * _span_mask
                        for (_start_positions, _span_mask)
                        in zip(torch.unbind(start_positions, dim=1),
                               torch.unbind(span_mask, dim=1))]
        end_losses = [loss_fct(end_logits, _end_positions) * _span_mask
                      for (_end_positions, _span_mask)
                      in zip(torch.unbind(end_positions, dim=1),
                             torch.unbind(span_mask, dim=1))]
        switch_losses = [loss_fct(switch_logits, _switch) * _answer_mask
                         for (_switch, _answer_mask)
                         in zip(torch.unbind(switches, dim=1),
                                torch.unbind(answer_mask, dim=1))]
        loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + \
                      torch.cat([t.unsqueeze(1) for t in end_losses], dim=1) + \
                      torch.cat([t.unsqueeze(1) for t in switch_losses], dim=1)
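        # Select how the per-candidate losses are aggregated:
        #   'first-only' - train on the first answer candidate only;
        #   'hard-em'    - with probability min(global_step / tau, 0.8), take the
        #                  minimum (hard) candidate loss, otherwise fall back to MML;
        #   'mml'        - marginal likelihood over all answer candidates.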
        if self.loss_type == 'first-only':
            total_loss = torch.sum(start_losses[0] + end_losses[0] + switch_losses[0])
        elif self.loss_type == 'hard-em':
            if numpy.random.random() < min(global_step / self.tau, 0.8):
                total_loss = self._take_min(loss_tensor)
            else:
                total_loss = self._take_mml(loss_tensor)
        elif self.loss_type == 'mml':
            total_loss = self._take_mml(loss_tensor)
        else:
            raise NotImplementedError()
        outputs = (total_loss,) + outputs

    return outputs  # (loss), start_logits, end_logits, switch_logits, (hidden_states), (attentions)
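
# The loss-selection branch above calls self._take_min and self._take_mml, which
# are defined elsewhere on this class and not shown in this excerpt. A minimal
# sketch of plausible implementations, assuming a zero entry in loss_tensor
# marks a masked-out (padded or non-span) answer candidate:

def _take_min(self, loss_tensor):
    # Hard-EM: keep only the smallest valid candidate loss per example,
    # pushing masked (zero) entries out of the way with a large offset.
    invalid = (loss_tensor == 0).float()
    shifted = loss_tensor + 2 * torch.max(loss_tensor) * invalid
    return torch.sum(torch.min(shifted, dim=1)[0])

def _take_mml(self, loss_tensor):
    # MML: marginalize the candidate likelihoods exp(-loss) over the candidate
    # dimension, zeroing masked entries, and return the summed negative log.
    invalid = (loss_tensor == 0).float()
    likelihood = torch.sum(torch.exp(-loss_tensor - 1e10 * invalid), dim=1)
    # Guard against log(0) when every candidate of an example is masked.
    return -torch.sum(torch.log(likelihood + (likelihood == 0).float()))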