def forward()

in pytorch_transformers/modeling_distilbert.py


    def forward(self, input_ids, attention_mask=None,
                start_positions=None, end_positions=None, switches=None,
                answer_mask=None, head_mask=None, global_step=None):
        distilbert_output = self.bert(input_ids=input_ids,
                                      attention_mask=attention_mask,
                                      head_mask=head_mask)
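        # distilbert_output[0] is the last hidden state; any further elements
        # (hidden states / attentions, if the model is configured to return them)
        # are passed through untouched in `outputs` below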
        hidden_states = distilbert_output[0]                                 # (bs, max_query_len, dim)
        hidden_states = self.dropout(hidden_states)                       # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)                           # (bs, max_query_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)                           # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1)                               # (bs, max_query_len)
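        # answer-type ("switch") classifier over the max-pooled sequence representation;
        # switch == 0 is treated as a span answer when building span_mask below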
        switch_logits = self.qa_classifier(torch.max(hidden_states, 1)[0])

        outputs = (start_logits, end_logits, switch_logits) + distilbert_output[1:]
        if start_positions is not None and end_positions is not None and switches is not None:
            assert global_step is not None
            # If we are on multi-GPU, split adds an extra dimension; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            if len(switches.size()) > 1:
                switches = switches.squeeze(-1)
            # start/end positions outside the model inputs are clamped to ignored_index
            # and skipped by the loss (ignore_index below)
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)
            # answer_mask marks valid (non-padded) answer candidates; span_mask keeps
            # only span-type candidates (switch == 0) for the start/end losses
            answer_mask = answer_mask.type(torch.FloatTensor).to(self.device)
            span_mask = answer_mask * (switches == 0).type(torch.FloatTensor).to(self.device)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index, reduce=False)
            start_losses = [loss_fct(start_logits, _start_positions) * _span_mask \
                            for (_start_positions, _span_mask) \
                            in zip(torch.unbind(start_positions, dim=1),
                                   torch.unbind(span_mask, dim=1))]
            end_losses = [loss_fct(end_logits, _end_positions) * _span_mask \
                            for (_end_positions, _span_mask) \
                            in zip(torch.unbind(end_positions, dim=1),
                                   torch.unbind(span_mask, dim=1))]
            switch_losses = [loss_fct(switch_logits, _switch) * _answer_mask \
                             for (_switch, _answer_mask) \
                             in zip(torch.unbind(switches, dim=1),
                                    torch.unbind(answer_mask, dim=1))]

            loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + \
                torch.cat([t.unsqueeze(1) for t in end_losses], dim=1) + \
                torch.cat([t.unsqueeze(1) for t in switch_losses], dim=1)
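            # loss_tensor: (bs, num_answer_candidates); entry (i, j) sums the start,
            # end, and switch losses for candidate j of example i, and is zero for
            # padded candidates (answer_mask == 0)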
            if self.loss_type=='first-only':
                total_loss = torch.sum(start_losses[0]+end_losses[0]+switch_losses[0])
            elif self.loss_type=='hard-em':
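                # anneal from MML to hard EM: with probability min(global_step / tau, 0.8)
                # take the minimum candidate loss, otherwise fall back to MML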
                if numpy.random.random()<min(global_step/self.tau, 0.8):
                    total_loss = self._take_min(loss_tensor)
                else:
                    total_loss = self._take_mml(loss_tensor)
            elif self.loss_type=='mml':
                total_loss = self._take_mml(loss_tensor)
            else:
                raise NotImplementedError()
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, switch_logits, (hidden_states), (attentions)
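
The `hard-em` and `mml` branches call two helpers, `_take_min` and `_take_mml`, which are defined elsewhere on the class and not shown in this snippet. As a rough, hypothetical sketch (not the repository's actual code), assuming `loss_tensor` is (bs, num_answer_candidates) with zeros for padded candidates, they could look like:

    def _take_min(self, loss_tensor):
        # hypothetical sketch: push zero (padded) entries above the current maximum,
        # take the per-example minimum loss, and sum over the batch (hard EM)
        masked = loss_tensor + 2 * torch.max(loss_tensor) * (loss_tensor == 0).float()
        return torch.sum(torch.min(masked, dim=1)[0])

    def _take_mml(self, loss_tensor):
        # hypothetical sketch: marginalize exp(-loss) over answer candidates
        # (padded entries are suppressed with a large offset), then take -log
        marginal = torch.sum(torch.exp(-loss_tensor - 1e10 * (loss_tensor == 0).float()), dim=1)
        return -torch.sum(torch.log(marginal + 1e-10))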