in rat-sql-gap/seq2struct/models/spider/spider_enc.py [0:0]
def forward(self, descs):
    """Encode a batch of Spider examples with BERT + relation-aware update layers.

    Args:
        descs: list of preprocessed example dicts. Each is assumed to contain
            'question' (token list), 'columns' (list of token lists, with a
            trailing column-type token), and 'tables' (list of token lists)
            — TODO confirm against the preprocessor.

    Returns:
        list of SpiderEncoderState, one per example in ``descs``.

    Note: examples whose concatenated BERT sequence exceeds 512 tokens are
    collected in ``long_seq_set`` but currently rejected by an assert below.
    """
    batch_token_lists = []
    batch_id_to_retrieve_question = []
    batch_id_to_retrieve_column = []
    batch_id_to_retrieve_table = []
    if self.summarize_header == "avg":
        # Second index set pointing at the *last* subword of each header so
        # a header can be summarized as (first_token + last_token) / 2.
        batch_id_to_retrieve_column_2 = []
        batch_id_to_retrieve_table_2 = []
    long_seq_set = set()  # batch indices of examples too long for BERT
    batch_id_map = {}  # original batch index -> index within the BERT batch

    for batch_idx, desc in enumerate(descs):
        qs = self.pad_single_sentence_for_bert(desc['question'], cls=True)
        if self.use_column_type:
            cols = [self.pad_single_sentence_for_bert(c, cls=False)
                    for c in desc['columns']]
        else:
            # Drop the trailing column-type token from each column.
            cols = [self.pad_single_sentence_for_bert(c[:-1], cls=False)
                    for c in desc['columns']]
        tabs = [self.pad_single_sentence_for_bert(t, cls=False)
                for t in desc['tables']]

        token_list = (qs
                      + [tok for col in cols for tok in col]
                      + [tok for tab in tabs for tok in tab])
        assert self.check_bert_seq(token_list)
        if len(token_list) > 512:
            # Exceeds BERT's positional limit; handled via encoder_long_seq.
            long_seq_set.add(batch_idx)
            continue

        q_b = len(qs)  # end of the question segment
        col_b = q_b + sum(len(c) for c in cols)  # end of the column segment
        # Leave out [CLS] and [SEP].
        question_indexes = list(range(q_b))[1:-1]
        # Use the first subword as the representative of each column/table.
        # (Renamed loop vars: the original shadowed `token_list` here.)
        column_indexes = np.cumsum(
            [q_b] + [len(col) for col in cols[:-1]]).tolist()
        table_indexes = np.cumsum(
            [col_b] + [len(tab) for tab in tabs[:-1]]).tolist()
        if self.summarize_header == "avg":
            # Last-subword index per header; the -2 offset skips the
            # segment-final padding tokens so each index lands on the
            # header's last content token.
            column_indexes_2 = np.cumsum(
                [q_b - 2] + [len(col) for col in cols]).tolist()[1:]
            table_indexes_2 = np.cumsum(
                [col_b - 2] + [len(tab) for tab in tabs]).tolist()[1:]

        indexed_token_list = self.tokenizer.convert_tokens_to_ids(token_list)
        batch_token_lists.append(indexed_token_list)

        batch_id_to_retrieve_question.append(
            torch.LongTensor(question_indexes).to(self._device))
        batch_id_to_retrieve_column.append(
            torch.LongTensor(column_indexes).to(self._device))
        batch_id_to_retrieve_table.append(
            torch.LongTensor(table_indexes).to(self._device))
        if self.summarize_header == "avg":
            # First-token index must not exceed the matching last-token index.
            assert all(i2 >= i1 for i1, i2 in zip(column_indexes, column_indexes_2))
            batch_id_to_retrieve_column_2.append(
                torch.LongTensor(column_indexes_2).to(self._device))
            assert all(i2 >= i1 for i1, i2 in zip(table_indexes, table_indexes_2))
            batch_id_to_retrieve_table_2.append(
                torch.LongTensor(table_indexes_2).to(self._device))

        batch_id_map[batch_idx] = len(batch_id_map)

    padded_token_lists, att_mask_lists, tok_type_lists = \
        self.pad_sequence_for_bert_batch(batch_token_lists)
    tokens_tensor = torch.LongTensor(padded_token_lists).to(self._device)
    att_masks_tensor = torch.LongTensor(att_mask_lists).to(self._device)
    # [0] is the last hidden state of the model output.
    enc_output = self.bert_model(tokens_tensor,
                                 attention_mask=att_masks_tensor)[0]

    # Identity pointer maps: each column/table points only to itself.
    column_pointer_maps = [
        {i: [i] for i in range(len(desc['columns']))}
        for desc in descs
    ]
    table_pointer_maps = [
        {i: [i] for i in range(len(desc['tables']))}
        for desc in descs
    ]

    assert len(long_seq_set) == 0  # remove them for now

    result = []
    for batch_idx, desc in enumerate(descs):
        c_boundary = list(range(len(desc["columns"]) + 1))
        t_boundary = list(range(len(desc["tables"]) + 1))

        if batch_idx in long_seq_set:
            q_enc, col_enc, tab_enc = self.encoder_long_seq(desc)
        else:
            bert_batch_idx = batch_id_map[batch_idx]
            q_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_question[bert_batch_idx]]
            col_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_column[bert_batch_idx]]
            tab_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_table[bert_batch_idx]]
            if self.summarize_header == "avg":
                col_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_column_2[bert_batch_idx]]
                tab_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_table_2[bert_batch_idx]]
                col_enc = (col_enc + col_enc_2) / 2.0  # avg of first and last token
                tab_enc = (tab_enc + tab_enc_2) / 2.0  # avg of first and last token

        assert q_enc.size()[0] == len(desc["question"])
        assert col_enc.size()[0] == c_boundary[-1]
        assert tab_enc.size()[0] == t_boundary[-1]

        # Relation-aware transformer update over question/column/table reps.
        q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \
            self.encs_update.forward_unbatched(
                desc,
                q_enc.unsqueeze(1),
                col_enc.unsqueeze(1),
                c_boundary,
                tab_enc.unsqueeze(1),
                t_boundary)
        # NOTE(review): removed leftover debug code that pickled each
        # example's encodings to "descs_{batch_idx}.pkl" on every forward
        # pass — an unconditional disk write in the hot path.

        memory = []
        if 'question' in self.include_in_memory:
            memory.append(q_enc_new_item)
        if 'column' in self.include_in_memory:
            memory.append(c_enc_new_item)
        if 'table' in self.include_in_memory:
            memory.append(t_enc_new_item)
        memory = torch.cat(memory, dim=1)

        result.append(SpiderEncoderState(
            state=None,
            memory=memory,
            question_memory=q_enc_new_item,
            schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1),
            # TODO: words should match memory
            words=desc['question'],
            pointer_memories={
                'column': c_enc_new_item,
                'table': t_enc_new_item,
            },
            pointer_maps={
                'column': column_pointer_maps[batch_idx],
                'table': table_pointer_maps[batch_idx],
            },
            m2c_align_mat=align_mat_item[0],
            m2t_align_mat=align_mat_item[1],
        ))
    return result