def forward()

in relogic/pretrainkit/models/semparse/logical_tabart.py [0:0]


  def forward(self, *input, **kwargs):
    """Compute the loss (and, in eval mode, predictions) for one pretraining task.

    All inputs are keyword arguments and are popped from ``kwargs``.
    Always required: ``input_ids``, ``pad_token_id`` (used to build the
    attention mask) and ``task``. Further keys depend on the task:

    * ``"text2sql"``: ``labels``, ``column_spans``, ``label_padding_id``;
      in eval mode also ``label_bos_id`` and ``label_eos_id``.
    * ``"mlm"`` / ``"col_rev"``: ``labels``; in eval mode also
      ``label_bos_id``, ``label_eos_id`` and ``label_padding_id``.
    * ``"recurring_mlm"``: ``y_ids`` and ``labels``; in eval mode also
      ``label_bos_id``, ``label_eos_id`` and ``label_padding_id``.
    * ``"col_pred"`` / ``"col_type"``: ``labels`` and ``column_spans``.
    * ``"value_pred"``: ``labels``, ``column_spans`` and ``value_spans``.
    * ``"table_pred"``: ``labels``.

    Returns:
      Training mode: ``(loss,)``.
      Eval mode: ``(loss.detach(), predictions)``. ``predictions`` holds the
      generated ids for the sequence tasks. For the column tasks it holds
      per-column predictions, set to -100 where the column span start is 0.
      For ``table_pred`` it holds the argmax class with a trailing singleton
      dimension.

    Raises:
      NotImplementedError: if ``task`` is not one of the names above.
    """
    input_ids = kwargs.pop("input_ids")
    pad_token_id = kwargs.pop("pad_token_id")
    # 1 for real encoder tokens, 0 for encoder padding.
    attention_mask = (input_ids != pad_token_id).long()
    task = kwargs.pop("task")

    def next_token_labels(output_ids, padding_id):
      # Targets are the ids shifted left by one position. Padding targets
      # become -100 so the loss ignores them.
      lm_labels = output_ids[:, 1:].clone()
      lm_labels[output_ids[:, 1:] == padding_id] = -100
      return lm_labels

    def shift_for_teacher_forcing(output_ids, padding_id):
      # Decoder inputs are the ids without their last position; labels are
      # the same ids without their first position.
      return output_ids[:, :-1].contiguous(), next_token_labels(output_ids, padding_id)

    def column_mask(column_spans):
      # Keep only column slots whose span start index is > 0. Presumably a
      # zero start marks a padded column slot -- TODO confirm with the loader.
      return column_spans.view(-1, 2)[:, 0] > 0

    def binary_column_loss(prob, label_ids, column_spans):
      # Summed BCE over the kept column slots, divided by the batch size.
      mask = column_mask(column_spans)
      return F.binary_cross_entropy(prob.view(-1)[mask],
                                    label_ids.view(-1)[mask].float(),
                                    reduction="sum") / label_ids.size(0)

    def three_way_column_loss(scores, label_ids, column_spans):
      # Summed 3-class cross entropy over the kept column slots, divided by the batch size.
      mask = column_mask(column_spans)
      return F.cross_entropy(scores.view(-1, 3)[mask],
                             label_ids.view(-1)[mask],
                             reduction="sum") / label_ids.size(0)

    def table_loss(scores, label_ids):
      # Summed 2-class cross entropy, divided by the batch size.
      return F.cross_entropy(scores.view(-1, 2),
                             label_ids.view(-1),
                             reduction="sum") / label_ids.size(0)

    def mask_padded_columns(predictions, column_spans):
      # In place: predictions for slots whose span start is 0 become -100.
      predictions[column_spans[:, :, 0] == 0] = -100
      return predictions

    def generate_with_bert(bos_id, eos_id, padding_id):
      # Beam-search decoding (3 beams) with self.bert, allowing at most
      # 5 tokens beyond the encoder length.
      generated = self.bert.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        num_beams=3,
        max_length=input_ids.size(1) + 5,
        length_penalty=2.0,
        early_stopping=True,
        use_cache=True,
        decoder_start_token_id=bos_id,
        eos_token_id=eos_id,
        pad_token_id=padding_id
      )
      # Drop the first generated position (presumably the decoder start token).
      return generated[:, 1:].contiguous()

    if self.training:
      if task == "text2sql":
        copy_span = None
        column_spans = kwargs.pop("column_spans")
        label_ids = kwargs.pop("labels")
        label_padding_id = kwargs.pop("label_padding_id")
        y_ids, lm_labels = shift_for_teacher_forcing(label_ids, label_padding_id)
        outputs = self.bert_for_texttosql(input_ids, column_spans=column_spans, copy_span=copy_span,
                                          attention_mask=attention_mask, decoder_input_ids=y_ids,
                                          lm_labels=lm_labels)
        return (outputs[0],)

      if task in ("mlm", "col_rev"):
        y_ids, lm_labels = shift_for_teacher_forcing(kwargs.pop("labels"), pad_token_id)
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            decoder_input_ids=y_ids, labels=lm_labels)
        return (outputs[0],)

      if task == "recurring_mlm":
        # Decoder inputs are supplied explicitly for this task.
        y_ids = kwargs.pop("y_ids")
        lm_labels = next_token_labels(kwargs.pop("labels"), pad_token_id)
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            decoder_input_ids=y_ids, labels=lm_labels)
        return (outputs[0],)

      if task == "col_pred":
        label_ids = kwargs.pop("labels")
        column_spans = kwargs.pop("column_spans")
        column_prob = self.column_prediction(input_ids, attention_mask, column_spans)
        return (binary_column_loss(column_prob, label_ids, column_spans),)

      if task == "col_type":
        label_ids = kwargs.pop("labels")
        column_spans = kwargs.pop("column_spans")
        column_type_scores = self.column_classification(input_ids, attention_mask, column_spans)
        return (three_way_column_loss(column_type_scores, label_ids, column_spans),)

      if task == "value_pred":
        label_ids = kwargs.pop("labels")
        column_spans = kwargs.pop("column_spans")
        value_spans = kwargs.pop("value_spans")
        value_prob = self.value_prediction(input_ids, attention_mask, column_spans, value_spans)
        return (binary_column_loss(value_prob, label_ids, column_spans),)

      if task == "table_pred":
        label_ids = kwargs.pop("labels")
        table_scores = self.table_pred(input_ids, attention_mask)
        return (table_loss(table_scores, label_ids),)

    else:
      if task == "text2sql":
        copy_span = None
        column_spans = kwargs.pop("column_spans")
        label_eos_id = kwargs.pop("label_eos_id")
        label_bos_id = kwargs.pop("label_bos_id")
        label_padding_id = kwargs.pop("label_padding_id")
        generated_ids = self.bert_for_texttosql.generate(
          input_ids=input_ids,
          column_spans=column_spans,
          copy_span=copy_span,
          attention_mask=attention_mask,
          num_beams=1,
          max_length=30,
          length_penalty=2.0,
          early_stopping=True,
          use_cache=True,
          decoder_start_token_id=label_bos_id,
          eos_token_id=label_eos_id,
          pad_token_id=label_padding_id,
          vocab_size=len(KEYWORDS)
        )

        y_ids, lm_labels = shift_for_teacher_forcing(kwargs.pop("labels"), label_padding_id)
        # Mirror the training-time call (lm_labels=, copy_span=) so the
        # reported eval loss is computed the same way as the training loss.
        outputs = self.bert_for_texttosql(input_ids, column_spans=column_spans, copy_span=copy_span,
                                          attention_mask=attention_mask, decoder_input_ids=y_ids,
                                          lm_labels=lm_labels)
        return (outputs[0].detach(), generated_ids)

      if task == "recurring_mlm":
        label_eos_id = kwargs.pop("label_eos_id")
        label_bos_id = kwargs.pop("label_bos_id")
        label_padding_id = kwargs.pop("label_padding_id")
        generated_ids = generate_with_bert(label_bos_id, label_eos_id, label_padding_id)
        y_ids = kwargs.pop("y_ids")
        lm_labels = next_token_labels(kwargs.pop("labels"), pad_token_id)
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            decoder_input_ids=y_ids, labels=lm_labels)
        return (outputs[0].detach(), generated_ids)

      if task in ("mlm", "col_rev"):
        label_eos_id = kwargs.pop("label_eos_id")
        label_bos_id = kwargs.pop("label_bos_id")
        label_padding_id = kwargs.pop("label_padding_id")
        generated_ids = generate_with_bert(label_bos_id, label_eos_id, label_padding_id)
        y_ids, lm_labels = shift_for_teacher_forcing(kwargs.pop("labels"), label_padding_id)
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            decoder_input_ids=y_ids, labels=lm_labels)
        return (outputs[0].detach(), generated_ids)

      if task == "col_pred":
        label_ids = kwargs.pop("labels")
        column_spans = kwargs.pop("column_spans")
        column_prob = self.column_prediction(input_ids, attention_mask, column_spans)
        generated_ids = mask_padded_columns((column_prob.squeeze(-1) > 0.5).long(), column_spans)
        column_loss = binary_column_loss(column_prob, label_ids, column_spans)
        return (column_loss.detach(), generated_ids)

      if task == "col_type":
        label_ids = kwargs.pop("labels")
        column_spans = kwargs.pop("column_spans")
        # Same 3-way head as the training branch (was column_prediction, the binary head).
        column_type_scores = self.column_classification(input_ids, attention_mask, column_spans)
        generated_ids = mask_padded_columns(column_type_scores.argmax(dim=-1), column_spans)
        column_loss = three_way_column_loss(column_type_scores, label_ids, column_spans)
        return (column_loss.detach(), generated_ids)

      if task == "value_pred":
        label_ids = kwargs.pop("labels")
        column_spans = kwargs.pop("column_spans")
        value_spans = kwargs.pop("value_spans")
        value_prob = self.value_prediction(input_ids, attention_mask, column_spans, value_spans)
        generated_ids = mask_padded_columns((value_prob.squeeze(-1) > 0.5).long(), column_spans)
        value_loss = binary_column_loss(value_prob, label_ids, column_spans)
        return (value_loss.detach(), generated_ids)

      if task == "table_pred":
        label_ids = kwargs.pop("labels")
        table_scores = self.table_pred(input_ids, attention_mask)
        generated_ids = table_scores.argmax(dim=-1).unsqueeze(-1)
        return (table_loss(table_scores, label_ids).detach(), generated_ids)

    raise NotImplementedError("Unknown task {}".format(task))