in relogic/pretrainkit/models/semparse/logical_tabart.py [0:0]
def forward(self, *input, **kwargs):
    """Dispatch a batch to the task-specific head.

    Training returns a 1-tuple ``(loss,)``; evaluation returns
    ``(detached_loss, predictions)`` where ``predictions`` are generated
    token ids (seq2seq tasks) or per-span class predictions (classification
    tasks, with ``-100`` marking padded spans).

    Expected ``kwargs`` (task-dependent): ``input_ids``, ``pad_token_id``,
    ``task``, ``labels``, plus extras such as ``column_spans``,
    ``value_spans``, ``y_ids``, ``label_padding_id``, ``label_bos_id``,
    ``label_eos_id``.

    Raises:
        NotImplementedError: for an unrecognized ``task``.
    """
    input_ids = kwargs.pop("input_ids")
    pad_token_id = kwargs.pop("pad_token_id")
    # Attend to every non-padding position.
    attention_mask = (input_ids != pad_token_id).long()
    task = kwargs.pop("task")
    if self.training:
        if task == "text2sql":
            copy_span = None
            column_spans = kwargs.pop("column_spans")
            label_ids = kwargs.pop("labels")
            label_padding_id = kwargs.pop("label_padding_id")
            y_ids, lm_labels = self._shift_for_teacher_forcing(label_ids, label_padding_id)
            # NOTE(review): this head takes the targets as ``lm_labels`` while
            # self.bert takes ``labels`` — presumably different transformers
            # API generations; confirm against the heads' signatures.
            outputs = self.bert_for_texttosql(
                input_ids, column_spans=column_spans, copy_span=copy_span,
                attention_mask=attention_mask, decoder_input_ids=y_ids,
                lm_labels=lm_labels)
            return (outputs[0],)
        if task == "mlm" or task == "col_rev":
            output_ids = kwargs.pop("labels")
            y_ids, lm_labels = self._shift_for_teacher_forcing(output_ids, pad_token_id)
            outputs = self.bert(
                input_ids, attention_mask=attention_mask,
                decoder_input_ids=y_ids, labels=lm_labels)
            return (outputs[0],)
        if task == "recurring_mlm":
            # Decoder inputs are supplied by the caller instead of being
            # derived by shifting the labels.
            y_ids = kwargs.pop("y_ids")
            output_ids = kwargs.pop("labels")
            lm_labels = output_ids[:, 1:].clone()
            lm_labels[output_ids[:, 1:] == pad_token_id] = -100
            outputs = self.bert(
                input_ids, attention_mask=attention_mask,
                decoder_input_ids=y_ids, labels=lm_labels)
            return (outputs[0],)
        if task == "col_pred":
            label_ids = kwargs.pop("labels")
            column_spans = kwargs.pop("column_spans")
            prob = self.column_prediction(input_ids, attention_mask, column_spans)
            return (self._masked_binary_loss(prob, label_ids, column_spans),)
        if task == "col_type":
            label_ids = kwargs.pop("labels")
            column_spans = kwargs.pop("column_spans")
            logits = self.column_classification(input_ids, attention_mask, column_spans)
            return (self._masked_3way_loss(logits, label_ids, column_spans),)
        if task == "value_pred":
            label_ids = kwargs.pop("labels")
            column_spans = kwargs.pop("column_spans")
            value_spans = kwargs.pop("value_spans")
            prob = self.value_prediction(input_ids, attention_mask, column_spans, value_spans)
            return (self._masked_binary_loss(prob, label_ids, column_spans),)
        if task == "table_pred":
            label_ids = kwargs.pop("labels")
            logits = self.table_pred(input_ids, attention_mask)
            loss = F.cross_entropy(
                logits.view(-1, 2), label_ids.view(-1),
                reduction="sum") / label_ids.size(0)
            return (loss,)
        raise NotImplementedError("Unknown task {}".format(task))
    else:
        if task == "text2sql":
            copy_span = None
            column_spans = kwargs.pop("column_spans")
            label_eos_id = kwargs.pop("label_eos_id")
            label_bos_id = kwargs.pop("label_bos_id")
            label_padding_id = kwargs.pop("label_padding_id")
            generated_ids = self.bert_for_texttosql.generate(
                input_ids=input_ids,
                column_spans=column_spans,
                copy_span=copy_span,
                attention_mask=attention_mask,
                num_beams=1,
                max_length=30,
                length_penalty=2.0,
                early_stopping=True,
                use_cache=True,
                decoder_start_token_id=label_bos_id,
                eos_token_id=label_eos_id,
                pad_token_id=label_padding_id,
                # Decoding is restricted to the SQL keyword vocabulary.
                vocab_size=len(KEYWORDS),
            )
            output_ids = kwargs.pop("labels")
            y_ids, lm_labels = self._shift_for_teacher_forcing(output_ids, label_padding_id)
            outputs = self.bert_for_texttosql(
                input_ids, column_spans=column_spans,
                attention_mask=attention_mask, decoder_input_ids=y_ids,
                labels=lm_labels)
            return (outputs[0].detach(), generated_ids)
        if task == "recurring_mlm":
            label_eos_id = kwargs.pop("label_eos_id")
            label_bos_id = kwargs.pop("label_bos_id")
            label_padding_id = kwargs.pop("label_padding_id")
            generated_ids = self.bert.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                num_beams=3,
                max_length=input_ids.size(1) + 5,
                length_penalty=2.0,
                early_stopping=True,
                use_cache=True,
                decoder_start_token_id=label_bos_id,
                eos_token_id=label_eos_id,
                pad_token_id=label_padding_id,
            )
            # Drop the decoder start token from the generated sequences.
            generated_ids = generated_ids[:, 1:].contiguous()
            y_ids = kwargs.pop("y_ids")
            output_ids = kwargs.pop("labels")
            lm_labels = output_ids[:, 1:].clone()
            lm_labels[output_ids[:, 1:] == pad_token_id] = -100
            outputs = self.bert(
                input_ids, attention_mask=attention_mask,
                decoder_input_ids=y_ids, labels=lm_labels)
            return (outputs[0].detach(), generated_ids)
        if task == "mlm" or task == "col_rev":
            label_eos_id = kwargs.pop("label_eos_id")
            label_bos_id = kwargs.pop("label_bos_id")
            label_padding_id = kwargs.pop("label_padding_id")
            generated_ids = self.bert.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                num_beams=3,
                max_length=input_ids.size(1) + 5,
                length_penalty=2.0,
                early_stopping=True,
                use_cache=True,
                decoder_start_token_id=label_bos_id,
                eos_token_id=label_eos_id,
                pad_token_id=label_padding_id,
            )
            # Drop the decoder start token from the generated sequences.
            generated_ids = generated_ids[:, 1:].contiguous()
            output_ids = kwargs.pop("labels")
            # NOTE(review): eval masks with label_padding_id while training
            # masks with pad_token_id — presumably the two ids coincide for
            # these tasks; confirm against the data collator.
            y_ids, lm_labels = self._shift_for_teacher_forcing(output_ids, label_padding_id)
            outputs = self.bert(
                input_ids, attention_mask=attention_mask,
                decoder_input_ids=y_ids, labels=lm_labels)
            return (outputs[0].detach(), generated_ids)
        if task == "col_pred":
            label_ids = kwargs.pop("labels")
            column_spans = kwargs.pop("column_spans")
            prob = self.column_prediction(input_ids, attention_mask, column_spans)
            # Threshold the binary probability; padded spans (start == 0)
            # are marked -100 so the evaluator ignores them.
            generated_ids = (prob.squeeze(-1) > 0.5).long()
            generated_ids[column_spans[:, :, 0] == 0] = -100
            loss = self._masked_binary_loss(prob, label_ids, column_spans)
            return (loss.detach(), generated_ids)
        if task == "col_type":
            label_ids = kwargs.pop("labels")
            column_spans = kwargs.pop("column_spans")
            # FIX: was self.column_prediction, which produces a single binary
            # probability per span; the 3-way cross-entropy below (and the
            # training branch) require the 3-class head.
            logits = self.column_classification(input_ids, attention_mask, column_spans)
            generated_ids = logits.argmax(dim=-1)
            generated_ids[column_spans[:, :, 0] == 0] = -100
            loss = self._masked_3way_loss(logits, label_ids, column_spans)
            return (loss.detach(), generated_ids)
        if task == "value_pred":
            label_ids = kwargs.pop("labels")
            column_spans = kwargs.pop("column_spans")
            value_spans = kwargs.pop("value_spans")
            prob = self.value_prediction(input_ids, attention_mask, column_spans, value_spans)
            generated_ids = (prob.squeeze(-1) > 0.5).long()
            generated_ids[column_spans[:, :, 0] == 0] = -100
            loss = self._masked_binary_loss(prob, label_ids, column_spans)
            return (loss.detach(), generated_ids)
        if task == "table_pred":
            label_ids = kwargs.pop("labels")
            logits = self.table_pred(input_ids, attention_mask)
            generated_ids = logits.argmax(dim=-1).unsqueeze(-1)
            loss = F.cross_entropy(
                logits.view(-1, 2), label_ids.view(-1),
                reduction="sum") / label_ids.size(0)
            return (loss.detach(), generated_ids)
        raise NotImplementedError()

@staticmethod
def _shift_for_teacher_forcing(label_ids, padding_id):
    """Split full label sequences for teacher forcing.

    Decoder inputs are ``labels[:, :-1]``; targets are ``labels[:, 1:]``
    with padding positions set to -100 (ignored by the loss).
    """
    y_ids = label_ids[:, :-1].contiguous()
    lm_labels = label_ids[:, 1:].clone()
    lm_labels[label_ids[:, 1:] == padding_id] = -100
    return y_ids, lm_labels

@staticmethod
def _masked_binary_loss(prob, label_ids, column_spans):
    """Summed BCE over real spans (start > 0), normalized by batch size."""
    label_mask = column_spans.view(-1, 2)[:, 0] > 0
    return F.binary_cross_entropy(
        prob.view(-1)[label_mask],
        label_ids.view(-1)[label_mask].float(),
        reduction="sum") / label_ids.size(0)

@staticmethod
def _masked_3way_loss(logits, label_ids, column_spans):
    """Summed 3-class cross-entropy over real spans, normalized by batch size."""
    label_mask = column_spans.view(-1, 2)[:, 0] > 0
    return F.cross_entropy(
        logits.view(-1, 3)[label_mask],
        label_ids.view(-1)[label_mask],
        reduction="sum") / label_ids.size(0)