def forward()

in rat-sql-gap/seq2struct/models/spider/spider_enc.py [0:0]


    def forward(self, descs):
        """Encode a batch of question/schema descriptions with BERT and relation-aware updates.

        Every description is flattened into a single BERT input sequence: the
        question (padded with ``cls=True``), followed by every column, followed by
        every table (each padded with ``cls=False``). One batched BERT forward pass
        produces contextual token vectors. From these, one vector per question
        token, per column and per table is gathered. The gathered vectors are
        passed to ``self.encs_update.forward_unbatched``, whose outputs populate
        one ``SpiderEncoderState`` per description.

        Args:
            descs: list of description dicts. Each is read for the keys
                ``'question'`` (question tokens), ``'columns'`` and ``'tables'``
                (one token sequence per column / table). When
                ``self.use_column_type`` is false, the last element of each column
                entry is dropped before encoding (presumably a column-type token;
                confirm against the preprocessor).

        Returns:
            list of ``SpiderEncoderState``, one per element of ``descs``, in the
            same order.

        Raises:
            AssertionError: if a flattened sequence is longer than 512 tokens
                (long sequences are not supported yet), or if an internal
                consistency check on the gathered encodings fails.
        """
        use_avg = self.summarize_header == "avg"

        def span_positions(offset, spans):
            # For spans laid out back-to-back starting at `offset`, return the
            # position of each span's first token and of the token just before
            # its final token (presumably the [SEP] appended by
            # pad_single_sentence_for_bert). Empty `spans` -> ([], []).
            firsts, lasts = [], []
            pos = offset
            for span in spans:
                firsts.append(pos)
                lasts.append(pos + len(span) - 2)
                pos += len(span)
            return firsts, lasts

        batch_token_lists = []
        batch_id_to_retrieve_question = []
        batch_id_to_retrieve_column = []
        batch_id_to_retrieve_table = []
        batch_id_to_retrieve_column_2 = []  # filled only when use_avg
        batch_id_to_retrieve_table_2 = []  # filled only when use_avg
        long_seq_set = set()
        batch_id_map = {}  # desc index -> row in the BERT batch (long examples are skipped)
        for batch_idx, desc in enumerate(descs):
            qs = self.pad_single_sentence_for_bert(desc['question'], cls=True)
            if self.use_column_type:
                cols = [self.pad_single_sentence_for_bert(c, cls=False) for c in desc['columns']]
            else:
                cols = [self.pad_single_sentence_for_bert(c[:-1], cls=False) for c in desc['columns']]
            tabs = [self.pad_single_sentence_for_bert(t, cls=False) for t in desc['tables']]

            token_list = qs + [c for col in cols for c in col] + \
                         [t for tab in tabs for t in tab]
            assert self.check_bert_seq(token_list)
            if len(token_list) > 512:
                long_seq_set.add(batch_idx)
                continue

            q_b = len(qs)
            col_b = q_b + sum(len(c) for c in cols)
            # leave out the question's [CLS] and [SEP]
            question_indexes = list(range(1, q_b - 1))
            # First-token position of every column/table (used as its
            # representation), plus the last-token position used by "avg".
            column_indexes, column_indexes_2 = span_positions(q_b, cols)
            table_indexes, table_indexes_2 = span_positions(col_b, tabs)

            batch_token_lists.append(self.tokenizer.convert_tokens_to_ids(token_list))

            batch_id_to_retrieve_question.append(
                torch.LongTensor(question_indexes).to(self._device))
            batch_id_to_retrieve_column.append(
                torch.LongTensor(column_indexes).to(self._device))
            batch_id_to_retrieve_table.append(
                torch.LongTensor(table_indexes).to(self._device))
            if use_avg:
                assert all(i2 >= i1 for i1, i2 in zip(column_indexes, column_indexes_2))
                batch_id_to_retrieve_column_2.append(
                    torch.LongTensor(column_indexes_2).to(self._device))
                assert all(i2 >= i1 for i1, i2 in zip(table_indexes, table_indexes_2))
                batch_id_to_retrieve_table_2.append(
                    torch.LongTensor(table_indexes_2).to(self._device))

            batch_id_map[batch_idx] = len(batch_id_map)

        # Long sequences are not supported yet ("remove them for now"). Check
        # before running BERT so we fail fast instead of after the forward pass.
        assert len(long_seq_set) == 0

        padded_token_lists, att_mask_lists, _ = self.pad_sequence_for_bert_batch(batch_token_lists)
        tokens_tensor = torch.LongTensor(padded_token_lists).to(self._device)
        att_masks_tensor = torch.LongTensor(att_mask_lists).to(self._device)

        # Element [0] of the model output is used as the per-token hidden states.
        enc_output = self.bert_model(tokens_tensor, attention_mask=att_masks_tensor)[0]

        result = []
        for batch_idx, desc in enumerate(descs):
            num_columns = len(desc["columns"])
            num_tables = len(desc["tables"])
            c_boundary = list(range(num_columns + 1))
            t_boundary = list(range(num_tables + 1))

            if batch_idx in long_seq_set:
                # Only reachable when assertions are disabled (python -O).
                q_enc, col_enc, tab_enc = self.encoder_long_seq(desc)
            else:
                bert_batch_idx = batch_id_map[batch_idx]
                row = enc_output[bert_batch_idx]
                q_enc = row[batch_id_to_retrieve_question[bert_batch_idx]]
                col_enc = row[batch_id_to_retrieve_column[bert_batch_idx]]
                tab_enc = row[batch_id_to_retrieve_table[bert_batch_idx]]

                if use_avg:
                    col_enc_2 = row[batch_id_to_retrieve_column_2[bert_batch_idx]]
                    tab_enc_2 = row[batch_id_to_retrieve_table_2[bert_batch_idx]]

                    col_enc = (col_enc + col_enc_2) / 2.0  # avg of first and last token
                    tab_enc = (tab_enc + tab_enc_2) / 2.0  # avg of first and last token

            assert q_enc.size()[0] == len(desc["question"])
            assert col_enc.size()[0] == c_boundary[-1]
            assert tab_enc.size()[0] == t_boundary[-1]

            q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \
                self.encs_update.forward_unbatched(
                    desc,
                    q_enc.unsqueeze(1),
                    col_enc.unsqueeze(1),
                    c_boundary,
                    tab_enc.unsqueeze(1),
                    t_boundary)

            memory = []
            if 'question' in self.include_in_memory:
                memory.append(q_enc_new_item)
            if 'column' in self.include_in_memory:
                memory.append(c_enc_new_item)
            if 'table' in self.include_in_memory:
                memory.append(t_enc_new_item)
            memory = torch.cat(memory, dim=1)

            result.append(SpiderEncoderState(
                state=None,
                memory=memory,
                question_memory=q_enc_new_item,
                schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1),
                # TODO: words should match memory
                words=desc['question'],
                pointer_memories={
                    'column': c_enc_new_item,
                    'table': t_enc_new_item,
                },
                # identity maps: item i points to memory slot i
                pointer_maps={
                    'column': {i: [i] for i in range(num_columns)},
                    'table': {i: [i] for i in range(num_tables)},
                },
                m2c_align_mat=align_mat_item[0],
                m2t_align_mat=align_mat_item[1],
            ))
        return result