def forward()

in rat-sql-gap/seq2struct/models/spider/spider_enc.py [0:0]


    def forward(self, descs):
        """Encode a batch of preprocessed examples into per-item encoder states.

        The question, the columns and the tables are first encoded
        independently, then updated with respect to each other by
        ``self.encs_update`` -- either once for the whole batch or item by
        item, depending on ``self.batch_encs_update``.

        Args:
            descs: list of preprocessed example dicts. Reads the keys
                ``'question'``, ``'columns'`` and ``'tables'``, and optionally
                ``'question_for_copying'`` (which must then have exactly one
                entry per question encoding position). Each dict is also
                passed through to ``self.encs_update``.

        Returns:
            list of ``SpiderEncoderState``, one per element of ``descs``.
            When ``self.batch_encs_update`` is true, the batched
            ``encs_update`` yields no alignment matrices, so
            ``m2c_align_mat`` and ``m2t_align_mat`` are ``None``.
        """
        # Encode the question (each question is wrapped in a singleton list).
        # q_enc: PackedSequencePlus, [batch, question len, recurrent_size]
        qs = [[desc['question']] for desc in descs]
        q_enc, _ = self.question_encoder(qs)

        # Encode the columns.
        # c_enc: PackedSequencePlus, [batch, sum of column lens, recurrent_size]
        # Consecutive entries of c_boundaries[b] delimit the span of each
        # column of item b inside c_enc.
        c_enc, c_boundaries = self.column_encoder([desc['columns'] for desc in descs])
        column_pointer_maps = [
            self._boundaries_to_pointer_map(item_boundaries)
            for item_boundaries in c_boundaries
        ]

        # Encode the tables.
        # t_enc: PackedSequencePlus, [batch, sum of table lens, recurrent_size]
        t_enc, t_boundaries = self.table_encoder([desc['tables'] for desc in descs])

        # Each table points directly at its own span. (An earlier variant also
        # pointed at the table's columns via desc['table_to_columns'] and
        # shifted the table spans by the number of column positions.)
        # NOTE(review): pointer_memories['table'] below is the column
        # encodings followed by the table encodings, yet these spans are not
        # shifted by the column length -- confirm how the decoder indexes
        # pointer_maps['table'] against that memory.
        table_pointer_maps = [
            self._boundaries_to_pointer_map(item_boundaries)
            for item_boundaries in t_boundaries
        ]

        # Update the three encodings with respect to each other. In batched
        # mode this happens once here; otherwise once per item in the loop.
        if self.batch_encs_update:
            q_enc_new, c_enc_new, t_enc_new = self.encs_update(
                descs, q_enc, c_enc, c_boundaries, t_enc, t_boundaries)

        result = []
        for batch_idx, desc in enumerate(descs):
            if self.batch_encs_update:
                q_enc_new_item = q_enc_new.select(batch_idx).unsqueeze(0)
                c_enc_new_item = c_enc_new.select(batch_idx).unsqueeze(0)
                t_enc_new_item = t_enc_new.select(batch_idx).unsqueeze(0)
                # The batched update returns no alignment matrices. Without
                # this assignment align_mat_item would be unbound when the
                # state is built below, so the batched path always failed
                # with UnboundLocalError.
                align_mat_item = (None, None)
            else:
                q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \
                    self.encs_update.forward_unbatched(
                        desc,
                        q_enc.select(batch_idx).unsqueeze(1),
                        c_enc.select(batch_idx).unsqueeze(1),
                        c_boundaries[batch_idx],
                        t_enc.select(batch_idx).unsqueeze(1),
                        t_boundaries[batch_idx])

            # Assemble the memory from the parts listed in include_in_memory,
            # always in the order question, column, table (concatenated along
            # dim 1). words_for_copying gets one entry per memory position;
            # only question positions can carry real words, the rest are ''.
            memory_parts = []
            words_for_copying = []
            if 'question' in self.include_in_memory:
                memory_parts.append(q_enc_new_item)
                if 'question_for_copying' in desc:
                    assert q_enc_new_item.shape[1] == len(desc['question_for_copying'])
                    words_for_copying += desc['question_for_copying']
                else:
                    words_for_copying += [''] * q_enc_new_item.shape[1]
            if 'column' in self.include_in_memory:
                memory_parts.append(c_enc_new_item)
                words_for_copying += [''] * c_enc_new_item.shape[1]
            if 'table' in self.include_in_memory:
                memory_parts.append(t_enc_new_item)
                words_for_copying += [''] * t_enc_new_item.shape[1]
            memory = torch.cat(memory_parts, dim=1)

            # Columns followed by tables along dim 1; serves both as the
            # schema memory and as the memory tables are pointed into.
            schema_memory = torch.cat((c_enc_new_item, t_enc_new_item), dim=1)

            result.append(SpiderEncoderState(
                state=None,
                memory=memory,
                question_memory=q_enc_new_item,
                schema_memory=schema_memory,
                # TODO: words should match memory
                words=words_for_copying,
                pointer_memories={
                    'column': c_enc_new_item,
                    'table': schema_memory,
                },
                pointer_maps={
                    'column': column_pointer_maps[batch_idx],
                    'table': table_pointer_maps[batch_idx],
                },
                m2c_align_mat=align_mat_item[0],
                m2t_align_mat=align_mat_item[1],
            ))
        return result

    @staticmethod
    def _boundaries_to_pointer_map(boundaries):
        """Map each item index ``i`` to the positions in ``[boundaries[i], boundaries[i + 1])``.

        ``boundaries`` is a sliceable sequence of offsets in which each
        consecutive pair delimits one item's span. The result has
        ``len(boundaries) - 1`` entries, or none if there are fewer than
        two offsets.
        """
        return {
            i: list(range(left, right))
            for i, (left, right) in enumerate(zip(boundaries, boundaries[1:]))
        }