def compute_mle_loss()

in rat-sql-gap/seq2struct/models/nl2code/decoder.py [0:0]


    def compute_mle_loss(self, enc_input, example, desc_enc, debug=False):
        """Compute the training loss for the gold tree of one example.

        Walks ``example.tree`` depth-first, left to right, using an explicit
        stack. For every node the action dictated by the gold tree is fed to
        a ``TrainTreeTraversal`` through ``step``. The result is the sum over
        dim 0 of the stacked entries of ``traversal.loss``.

        Args:
            enc_input: Mapping indexed with ``'columns'`` and ``'question'``.
                Only the lengths of those entries are read, and only when
                ``self.sup_att == '1h'``.
            example: Object whose ``tree`` attribute holds the gold tree
                (nested dicts with a ``'_type'`` key, lists/tuples, and
                leaf values).
            desc_enc: Passed through to ``TrainTreeTraversal``; its
                ``pointer_maps`` is consulted for pointer-typed fields.
            debug: Forwarded to ``TrainTreeTraversal``. When true, the
                traversal history is returned alongside the loss.

        Returns:
            The summed loss tensor, or ``(loss, history)`` when ``debug`` is
            true, where ``history`` is a list of ``attr.asdict`` dicts built
            from ``traversal.history``.

        Raises:
            AssertionError: If the traversal is not in the state expected for
                the node being processed, or a pointer node is not an ``int``.
        """
        traversal = TrainTreeTraversal(self, desc_enc, debug)
        traversal.step(None)
        # LIFO stack: children are always pushed in reverse order so that
        # they are popped (and therefore processed) left to right.
        stack = [
            TreeState(
                node=example.tree,
                parent_field_type=self.preproc.grammar.root_type,
            )
        ]
        while stack:
            item = stack.pop()
            node = item.node
            parent_field_type = item.parent_field_type

            if isinstance(node, (list, tuple)):
                # Sequence field: first pick the rule that fixes its length.
                rule = (parent_field_type + '*', len(node))
                rule_idx = self.rules_index[rule]
                assert traversal.cur_item.state == TreeTraversal.State.LIST_LENGTH_APPLY
                traversal.step(rule_idx)

                # Elements of a sum-typed sequence get a dedicated field type
                # when the preprocessor uses separate sequence-element rules.
                elem_type = parent_field_type
                if self.preproc.use_seq_elem_rules and parent_field_type in self.ast_wrapper.sum_types:
                    elem_type += '_seq_elem'

                stack.extend(
                    TreeState(node=elem, parent_field_type=elem_type)
                    for elem in reversed(node))
                continue

            if parent_field_type in self.preproc.grammar.pointers:
                assert isinstance(node, int)
                assert traversal.cur_item.state == TreeTraversal.State.POINTER_APPLY
                pointer_map = desc_enc.pointer_maps.get(parent_field_type)
                if not pointer_map:
                    # No (or empty) map for this pointer type: the node value
                    # itself is the target.
                    traversal.step(node)
                    continue

                values = pointer_map[node]
                if self.sup_att != '1h':
                    traversal.step(values[0], values[1:])
                    continue

                # With '1h' attention supervision an extra index is passed to
                # the traversal: the node value shifted by an offset.
                # NOTE(review): a pointer map as long as the column list is
                # presumably the column pointer, anything else the table
                # pointer; the offsets presumably locate the item in a
                # question/columns/tables memory layout - confirm.
                num_columns = len(enc_input['columns'])
                offset = 0
                if len(pointer_map) != num_columns:
                    offset += num_columns
                if self.attn_type != 'sep':
                    offset += len(enc_input['question'])
                traversal.step(values[0], values[1:], node + offset)
                continue

            if parent_field_type in self.ast_wrapper.primitive_types:
                # identifier, int, string, bytes, object, singleton
                # - could be bytes, str, int, float, bool, NoneType
                # - terminal tokens vocabulary is created by turning everything into a string (with `str`)
                # - at decoding time, cast back to str/int/float/bool
                tokens = self.preproc.grammar.tokenize_field_value(node) + [vocab.EOS]
                for token in tokens:
                    assert traversal.cur_item.state == TreeTraversal.State.GEN_TOKEN
                    traversal.step(token)
                continue

            type_info = self.ast_wrapper.singular_types[node['_type']]

            if parent_field_type in self.preproc.sum_type_constructors:
                # ApplyRule, like expr -> Call
                rule = (parent_field_type, type_info.name)
                rule_idx = self.rules_index[rule]
                assert traversal.cur_item.state == TreeTraversal.State.SUM_TYPE_APPLY
                extra_rules = [
                    self.rules_index[parent_field_type, extra_type]
                    for extra_type in node.get('_extra_types', [])]
                traversal.step(rule_idx, extra_rules)

            if type_info.fields:
                # ApplyRule, like Call -> expr[func] expr*[args] keyword*[keywords]
                # Figure out which rule needs to be applied
                present = get_field_presence_info(self.ast_wrapper, node, type_info.fields)
                rule = (node['_type'], tuple(present))
                rule_idx = self.rules_index[rule]
                assert traversal.cur_item.state == TreeTraversal.State.CHILDREN_APPLY
                traversal.step(rule_idx)

            # reversed so that we perform a DFS in left-to-right order
            for field_info in reversed(type_info.fields):
                if field_info.name not in node:
                    continue

                stack.append(
                    TreeState(
                        node=node[field_info.name],
                        parent_field_type=field_info.type,
                    ))

        loss = torch.sum(torch.stack(tuple(traversal.loss), dim=0), dim=0)
        if debug:
            return loss, [attr.asdict(entry) for entry in traversal.history]
        return loss