def beam_search_with_oracle_column()

in rat-sql-gap/seq2struct/models/spider/spider_beam_search.py [0:0]


def _force_gold_pointer(hyp, gold_stack, node_type):
    """Restrict ``hyp.next_choices`` to the next gold pointer and consume it.

    The next gold value is ``gold_stack[-1]``. Each entry of
    ``hyp.next_choices`` holds its candidate choice at ``entry[0]``; the first
    entry whose choice equals the gold value becomes the only remaining choice,
    and the gold value is popped from ``gold_stack``.

    Raises:
        AssertionError: if no entry matches the gold value. Raised explicitly
            (not via an ``assert`` statement) so the check still runs under
            ``python -O``.
    """
    gold = gold_stack[-1]
    for entry in hyp.next_choices:
        if entry[0] == gold:
            hyp.next_choices = [entry]
            gold_stack.pop()
            return
    raise AssertionError(
        'gold {} {!r} is not among the decoder next choices'.format(node_type, gold))


def beam_search_with_oracle_column(model, orig_item, preproc_item, beam_size, max_steps, visualize_flag=False):
    """Decode greedily while forcing every column/table pointer to the gold value.

    Whenever the single live hypothesis is at a ``POINTER_APPLY`` step for a
    ``"column"`` or ``"table"`` node, its candidate choices are narrowed to the
    next gold column/table taken from the gold AST. All other decisions are
    left to the model.

    Args:
        model: provides ``begin_inference(orig_item, preproc_item)`` and
            ``decoder.ast_wrapper.find_all_descendants_of_type``.
        orig_item: original example, passed through to ``begin_inference``.
        preproc_item: preprocessed example; ``preproc_item[1].tree`` is used as
            the root of the gold AST.
        beam_size: must be 1 (the oracle only rewrites a single hypothesis).
        max_steps: maximum number of decoding steps.
        visualize_flag: when True, print each step index and report a mismatch
            between the number of gold pointers and pointer steps taken.

    Returns:
        List of finished ``Hypothesis`` objects, sorted by ``score`` descending.

    Raises:
        AssertionError: if ``beam_size != 1``, or if a gold column/table is not
            among the model's next choices at the corresponding pointer step.
    """
    if beam_size != 1:
        # Checked before running the model; the forcing below only touches beam[0].
        raise AssertionError('beam_search_with_oracle_column requires beam_size == 1')

    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    beam = [Hypothesis(inference_state, next_choices)]
    finished = []

    # Collect every column/table mentioned in the gold SQL. Each list is
    # consumed from its end, i.e. in the reverse of the order in which
    # find_all_descendants_of_type yields the nodes.
    root_node = preproc_item[1].tree
    ast_wrapper = model.decoder.ast_wrapper
    gold_cols = list(ast_wrapper.find_all_descendants_of_type(root_node, "column"))
    gold_tabs = list(ast_wrapper.find_all_descendants_of_type(root_node, "table"))
    num_gold_pointers = len(gold_cols) + len(gold_tabs)

    # Number of POINTER_APPLY steps actually taken (diagnostic only).
    predict_counter = 0

    for step in range(max_steps):
        if visualize_flag:
            print('step:')
            print(step)
        # Check if all beams are finished
        if len(finished) == beam_size:
            break

        # With beam_size == 1 every step yields exactly one live hypothesis
        # unless decoding has finished (handled by the break above).
        assert len(beam) == 1
        hyp = beam[0]
        cur_item = hyp.inference_state.cur_item
        if cur_item.state == TreeTraversal.State.POINTER_APPLY:
            predict_counter += 1
            # Hijack the next choice with the gold pointer while gold remains.
            if cur_item.node_type == "column" and gold_cols:
                _force_gold_pointer(hyp, gold_cols, "column")
            elif cur_item.node_type == "table" and gold_tabs:
                _force_gold_pointer(hyp, gold_tabs, "table")

        # For each hypothesis, get possible expansions and score them.
        candidates = []
        for cand_hyp in beam:
            candidates += [(cand_hyp, choice, choice_score.item(),
                            cand_hyp.score + choice_score.item())
                           for choice, choice_score in cand_hyp.next_choices]

        # Keep the top K expansions
        candidates.sort(key=operator.itemgetter(3), reverse=True)
        candidates = candidates[:beam_size - len(finished)]

        # Create the new hypotheses from the expansions
        beam = []
        for cand_hyp, choice, choice_score, cum_score in candidates:
            inference_state = cand_hyp.inference_state.clone()
            next_choices = inference_state.step(choice)
            if next_choices is None:
                finished.append(Hypothesis(
                    inference_state,
                    None,
                    cum_score,
                    cand_hyp.choice_history + [choice],
                    cand_hyp.score_history + [choice_score]))
            else:
                beam.append(
                    Hypothesis(inference_state, next_choices, cum_score,
                               cand_hyp.choice_history + [choice],
                               cand_hyp.score_history + [choice_score]))

    if visualize_flag and predict_counter != num_gold_pointers:
        print("The number of column/tables are not matched")

    finished.sort(key=operator.attrgetter('score'), reverse=True)
    return finished