def beam_search_with_heuristics()

in rat-sql-gap/seq2struct/models/spider/spider_beam_search.py

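Two-phase beam search used at inference time for Spider: standard beam search
runs until a hypothesis is about to expand a `from` clause; candidate FROM
subtrees are then enumerated under a small budget and filtered with
schema-level heuristics (no duplicate tables, foreign keys joined
consistently, and every mentioned column's table present in FROM) before
decoding resumes. Surviving finished hypotheses are cached, and the best
`beam_size` of them are returned.

The hypothesis record extends the base beam-search hypothesis with the
histories the filter inspects. A minimal sketch, reconstructed from the
constructor calls below (the attrs-based layout is an assumption):

import attr

@attr.s
class Hypothesis:  # stand-in for the base hypothesis class defined in this file
    inference_state = attr.ib()
    next_choices = attr.ib()
    score = attr.ib(default=0)
    choice_history = attr.ib(factory=list)
    score_history = attr.ib(factory=list)

@attr.s
class Hypothesis4Filtering(Hypothesis):
    column_history = attr.ib(factory=list)      # all column choices so far
    table_history = attr.ib(factory=list)       # tables chosen inside FROM
    key_column_history = attr.ib(factory=list)  # key columns chosen inside FROM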

import operator

import networkx as nx

# `merge_beams` is defined alongside this function in spider_beam_search.py;
# `TreeTraversal` is imported at module level from the decoder.

def beam_search_with_heuristics(model, orig_item, preproc_item, beam_size, max_steps, from_cond=True):
    """
    Find the valid FROM clasue with beam search
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    beam = [Hypothesis4Filtering(inference_state, next_choices)]

    cached_finished_seqs = [] # cache filtered trajectories
    beam_prefix = beam
    while True:
        # phase 1: beam-search the prefix up to the next FROM clause
        prefixes2fill_from = []
        for step in range(max_steps):
            if len(prefixes2fill_from) >= beam_size:
                break

            candidates = []
            for hyp in beam_prefix:
                if hyp.inference_state.cur_item.state == TreeTraversal.State.CHILDREN_APPLY \
                        and hyp.inference_state.cur_item.node_type == "from":
                    prefixes2fill_from.append(hyp)
                else:
                    candidates += [(hyp, choice, choice_score.item(),
                                hyp.score + choice_score.item())
                            for choice, choice_score in hyp.next_choices]
            candidates.sort(key=operator.itemgetter(3), reverse=True)
            candidates = candidates[:beam_size-len(prefixes2fill_from)]

            # Create the new hypotheses from the expansions
            beam_prefix = []
            for hyp, choice, choice_score, cum_score in candidates:
                inference_state = hyp.inference_state.clone()

                # record column choices for the filtering stage
                column_history = hyp.column_history[:]
                if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY and \
                        hyp.inference_state.cur_item.node_type == "column":
                    column_history = column_history + [choice]

                next_choices = inference_state.step(choice)
                assert next_choices is not None
                beam_prefix.append(
                    Hypothesis4Filtering(inference_state, next_choices, cum_score,
                            hyp.choice_history + [choice],
                            hyp.score_history + [choice_score],
                            column_history))

        prefixes2fill_from.sort(key=operator.attrgetter('score'), reverse=True)
        # assert len(prefixes2fill_from) == beam_size

        # phase 2: exhaustively enumerate candidate FROM clauses for each prefix
        beam_from = prefixes2fill_from
        max_size = 6  # budget for enumerated FROM candidates
        unfiltered_finished = []
        prefixes_unfinished = []
        for step in range(max_steps):
            if len(unfiltered_finished) + len(prefixes_unfinished) > max_size:
                break

            candidates = []
            for hyp in beam_from:
                if step > 0 and hyp.inference_state.cur_item.state == TreeTraversal.State.CHILDREN_APPLY \
                        and hyp.inference_state.cur_item.node_type == "from":
                    prefixes_unfinished.append(hyp)
                else:
                    candidates += [(hyp, choice, choice_score.item(),
                            hyp.score + choice_score.item())
                        for choice, choice_score in hyp.next_choices]
            candidates.sort(key=operator.itemgetter(3), reverse=True)
            candidates = candidates[:max_size - len(prefixes_unfinished)]

            beam_from = []
            for hyp, choice, choice_score, cum_score in candidates:
                inference_state = hyp.inference_state.clone()

                # record table and key-column choices made inside FROM
                table_history = hyp.table_history[:]
                key_column_history = hyp.key_column_history[:]
                if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY:
                    if hyp.inference_state.cur_item.node_type == "table":
                        table_history = table_history + [choice]
                    elif hyp.inference_state.cur_item.node_type == "column":
                        key_column_history = key_column_history + [choice]

                next_choices = inference_state.step(choice)
                if next_choices is None:
                    unfiltered_finished.append(Hypothesis4Filtering(
                        inference_state,
                        None,
                        cum_score,
                        hyp.choice_history + [choice],
                        hyp.score_history + [choice_score],
                        hyp.column_history, table_history,
                        key_column_history))
                else:
                    beam_from.append(
                        Hypothesis4Filtering(inference_state, next_choices, cum_score,
                                hyp.choice_history + [choice],
                                hyp.score_history + [choice_score],
                                hyp.column_history, table_history,
                                key_column_history))

        unfiltered_finished.sort(key=operator.attrgetter('score'), reverse=True)

        # phase 3: filter finished FROM clauses with schema heuristics
        filtered_finished = []
        for hyp in unfiltered_finished:
            mentioned_column_ids = set(hyp.column_history)
            mentioned_key_column_ids = set(hyp.key_column_history)
            mentioned_table_ids = set(hyp.table_history)

            # reject FROM clauses that mention the same table twice
            if len(mentioned_table_ids) != len(hyp.table_history):
                continue

            # foreign keys should be used correctly: joined tables must be
            # connected in the schema's foreign-key graph
            # NOTE: the new version does not predict conditions in the FROM clause anymore
            if from_cond:
                covered_tables = set()
                must_include_key_columns = set()
                candidate_table_ids = sorted(mentioned_table_ids)
                start_table_id = candidate_table_ids[0]
                for table_id in candidate_table_ids[1:]:
                    if table_id in covered_tables:
                        continue
                    try:
                        path = nx.shortest_path(
                            orig_item.schema.foreign_key_graph, source=start_table_id, target=table_id)
                    except (nx.NetworkXNoPath, nx.NodeNotFound):
                        covered_tables.add(table_id)
                        continue
                    
                    for source_table_id, target_table_id in zip(path, path[1:]):
                        if target_table_id in covered_tables:
                            continue
                        if target_table_id not in mentioned_table_ids:
                            continue
                        col1, col2 = orig_item.schema.foreign_key_graph[source_table_id][target_table_id]['columns']
                        must_include_key_columns.add(col1)
                        must_include_key_columns.add(col2)
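                    # a valid FROM must use exactly the key columns required to
                    # join the mentioned tables: e.g. two tables connected by a
                    # single foreign key must contribute precisely that column
                    # pair, nothing more and nothing less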
                if not must_include_key_columns == mentioned_key_column_ids:
                    continue

            # every table whose column is mentioned must itself appear in FROM
            must_table_ids = set()
            for col in mentioned_column_ids:
                tab_ = orig_item.schema.columns[col].table
                if tab_ is not None:
                    must_table_ids.add(tab_.id)
            if not must_table_ids.issubset(mentioned_table_ids):
                continue
            
            filtered_finished.append(hyp)
         
        filtered_finished.sort(key=operator.attrgetter('score'), reverse=True)
        prefixes_unfinished.sort(key=operator.attrgetter('score'), reverse=True)

        # merge unfinished prefixes with finished candidates, trimming to beam_size
        prefixes_, filtered_ = merge_beams(prefixes_unfinished, filtered_finished, beam_size)

        if filtered_:
            cached_finished_seqs = cached_finished_seqs + filtered_
            cached_finished_seqs.sort(key=operator.attrgetter('score'), reverse=True)

        if prefixes_ and len(prefixes_[0].choice_history) < 200:  # cap runaway decodes
            beam_prefix = prefixes_
            # reset per-FROM histories before searching for the next FROM clause
            for hyp in beam_prefix:
                hyp.table_history = []
                hyp.column_history = []
                hyp.key_column_history = []
        elif cached_finished_seqs:
            return cached_finished_seqs[:beam_size]
        else:
            # everything was filtered out: fall back to unfiltered hypotheses
            return unfiltered_finished[:beam_size]
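
A hypothetical usage sketch (the inference setup and the `finalize()` call
follow the usual seq2struct inference flow; exact names are assumptions):

beams = beam_search_with_heuristics(
    model, orig_item, preproc_item, beam_size=20, max_steps=1000)
for hyp in beams:
    # each surviving hypothesis carries a cumulative log-probability and the
    # sequence of grammar choices that produced it
    model_output, inferred_code = hyp.inference_state.finalize()
    print(hyp.score, inferred_code)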