def beam_search_with_oracle_sketch()

in rat-sql-gap/seq2struct/models/spider/spider_beam_search.py [0:0]


def beam_search_with_oracle_sketch(model, orig_item, preproc_item, beam_size, max_steps, visualize_flag=False):
    """Replay the gold tree of ``preproc_item`` through the model's inference state.

    The tree stored at ``preproc_item[1].tree`` is walked depth-first, left to
    right, and every decision it implies (list length, pointer value,
    primitive tokens, sum-type constructor, field-presence rule) is pushed
    into the inference state with ``step``.  No search takes place: each
    forced choice is recorded with a score of 0.

    Args:
        model: Object exposing ``begin_inference`` and a ``decoder`` with
            ``preproc``, ``ast_wrapper`` and ``rules_index`` attributes.
        orig_item: Original example; ``orig_item.code`` must be parseable by
            the decoder grammar.
        preproc_item: Preprocessed example; element 1 carries the gold tree.
        beam_size: Unused; kept for signature compatibility.
        max_steps: Unused; kept for signature compatibility.
        visualize_flag: Unused; kept for signature compatibility.

    Returns:
        A one-element list holding the final ``Hypothesis``, or an empty list
        when the gold code does not parse or a required list-length rule is
        missing from ``rules_index``.

    Raises:
        KeyError: If a sum-type or field-presence rule required by the gold
            tree is missing from ``rules_index``.
        AssertionError: If the inference state is not in the state expected
            for the decision being forced.
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)

    parsed = model.decoder.preproc.grammar.parse(orig_item.code, "val")
    if not parsed:
        return []

    decoder = model.decoder
    preproc = decoder.preproc
    grammar = preproc.grammar
    ast_wrapper = decoder.ast_wrapper
    rules_index = decoder.rules_index

    hyp = Hypothesis(inference_state, next_choices)

    def extend(prev, choices):
        # One zero score per forced choice keeps both histories the same length.
        return Hypothesis(
            inference_state,
            None,
            0,
            prev.choice_history + choices,
            prev.score_history + [0] * len(choices))

    # Used as a LIFO stack so the traversal is depth-first.
    stack = [
        TreeState(
            node=preproc_item[1].tree,
            parent_field_type=grammar.root_type,
        )
    ]

    while stack:
        item = stack.pop()
        node = item.node
        parent_field_type = item.parent_field_type

        if isinstance(node, (list, tuple)):
            # Sequence field: force the rule that fixes the list length.
            rule = (parent_field_type + '*', len(node))
            if rule not in rules_index:
                return []
            rule_idx = rules_index[rule]
            assert inference_state.cur_item.state == TreeTraversal.State.LIST_LENGTH_APPLY
            inference_state.step(rule_idx)

            if preproc.use_seq_elem_rules and parent_field_type in ast_wrapper.sum_types:
                parent_field_type += '_seq_elem'

            # Pushed in reverse so elements are popped left to right.
            for elem in reversed(node):
                stack.append(
                    TreeState(
                        node=elem,
                        parent_field_type=parent_field_type,
                    ))

            hyp = extend(hyp, [rule_idx])
            continue

        if parent_field_type in grammar.pointers:
            # NOTE: the gold pointer value is forced too.  The original carried
            # a commented-out variant that took the best-scoring entry of
            # next_choices here instead.
            assert inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY
            assert isinstance(node, int)
            inference_state.step(node)
            hyp = extend(hyp, [node])
            continue

        if parent_field_type in ast_wrapper.primitive_types:
            # Primitive value: feed its tokens one by one, then the terminator.
            field_value_split = grammar.tokenize_field_value(node) + ['<EOS>']
            for token in field_value_split:
                inference_state.step(token)
            hyp = extend(hyp, field_value_split)
            continue

        type_info = ast_wrapper.singular_types[node['_type']]

        if parent_field_type in preproc.sum_type_constructors:
            # ApplyRule, like expr -> Call
            rule = (parent_field_type, type_info.name)
            rule_idx = rules_index[rule]
            # Was a bare comparison whose result was discarded; restored as an
            # assert to match the other state checks in this function.
            assert inference_state.cur_item.state == TreeTraversal.State.SUM_TYPE_APPLY
            extra_rules = [
                rules_index[parent_field_type, extra_type]
                for extra_type in node.get('_extra_types', [])]
            inference_state.step(rule_idx, extra_rules)
            hyp = extend(hyp, [rule_idx])

        if type_info.fields:
            # ApplyRule, like Call -> expr[func] expr*[args] keyword*[keywords]
            # The rule is selected by which fields are present on this node.
            present = get_field_presence_info(ast_wrapper, node, type_info.fields)
            rule = (node['_type'], tuple(present))
            rule_idx = rules_index[rule]
            inference_state.step(rule_idx)
            hyp = extend(hyp, [rule_idx])

        # reversed so that we perform a DFS in left-to-right order
        for field_info in reversed(type_info.fields):
            if field_info.name not in node:
                continue

            stack.append(
                TreeState(
                    node=node[field_info.name],
                    parent_field_type=field_info.type,
                ))

    return [hyp]