in rat-sql-gap/seq2struct/models/spider/spider_beam_search.py [0:0]
def beam_search_with_heuristics(model, orig_item, preproc_item, beam_size, max_steps, from_cond=True):
"""
Find the valid FROM clasue with beam search
"""
inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
beam = [Hypothesis4Filtering(inference_state, next_choices)]
cached_finished_seqs = [] # cache filtered trajectories
beam_prefix = beam
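# Alternate between two phases until decoding terminates:
#   1) beam-search hypothesis prefixes until each one reaches a FROM clause;
#   2) enumerate FROM-clause completions and keep only those that pass the
#      schema-based filtering heuristics below.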
while True:
# search prefixes with beam search
prefixes2fill_from = []
for step in range(max_steps):
if len(prefixes2fill_from) >= beam_size:
break
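# Expand every hypothesis by one decoding step; hypotheses that have reached
# the FROM clause are set aside in prefixes2fill_from instead of being expanded.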
candidates = []
for hyp in beam_prefix:
if hyp.inference_state.cur_item.state == TreeTraversal.State.CHILDREN_APPLY \
and hyp.inference_state.cur_item.node_type == "from":
prefixes2fill_from.append(hyp)
else:
candidates += [(hyp, choice, choice_score.item(),
hyp.score + choice_score.item())
for choice, choice_score in hyp.next_choices]
candidates.sort(key=operator.itemgetter(3), reverse=True)
candidates = candidates[:beam_size-len(prefixes2fill_from)]
# Create the new hypotheses from the expansions
beam_prefix = []
for hyp, choice, choice_score, cum_score in candidates:
inference_state = hyp.inference_state.clone()
# cache column choice
column_history = hyp.column_history[:]
if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY and \
hyp.inference_state.cur_item.node_type == "column":
column_history = column_history + [choice]
next_choices = inference_state.step(choice)
assert next_choices is not None
beam_prefix.append(
Hypothesis4Filtering(inference_state, next_choices, cum_score,
hyp.choice_history + [choice],
hyp.score_history + [choice_score],
column_history))
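# Rank the collected FROM-clause prefixes by cumulative score.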
prefixes2fill_from.sort(key=operator.attrgetter('score'), reverse=True)
# assert len(prefixes) == beam_size
# enumerating FROM-clause completions
beam_from = prefixes2fill_from
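# Cap on how many FROM-clause candidates (finished plus deferred prefixes)
# are enumerated in this round.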
max_size = 6
unfiltered_finished = []
prefixes_unfinished = []
for step in range(max_steps):
if len(unfiltered_finished) + len(prefixes_unfinished) > max_size:
break
candidates = []
for hyp in beam_from:
if step > 0 and hyp.inference_state.cur_item.state == TreeTraversal.State.CHILDREN_APPLY \
and hyp.inference_state.cur_item.node_type == "from":
prefixes_unfinished.append(hyp)
else:
candidates += [(hyp, choice, choice_score.item(),
hyp.score + choice_score.item())
for choice, choice_score in hyp.next_choices]
candidates.sort(key=operator.itemgetter(3), reverse=True)
candidates = candidates[:max_size - len(prefixes_unfinished)]
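# Expand each surviving hypothesis by one step: hypotheses that finish the
# whole tree (next_choices is None) become finished candidates; the rest stay
# in beam_from for the next step.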
beam_from = []
for hyp, choice, choice_score, cum_score in candidates:
inference_state = hyp.inference_state.clone()
# cache table and key-column choices
table_history = hyp.table_history[:]
key_column_history = hyp.key_column_history[:]
if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY:
if hyp.inference_state.cur_item.node_type == "table":
table_history = table_history + [choice]
elif hyp.inference_state.cur_item.node_type == "column":
key_column_history = key_column_history + [choice]
next_choices = inference_state.step(choice)
if next_choices is None:
unfiltered_finished.append(Hypothesis4Filtering(
inference_state,
None,
cum_score,
hyp.choice_history + [choice],
hyp.score_history + [choice_score],
hyp.column_history, table_history,
key_column_history))
else:
beam_from.append(
Hypothesis4Filtering(inference_state, next_choices, cum_score,
hyp.choice_history + [choice],
hyp.score_history + [choice_score],
hyp.column_history, table_history,
key_column_history))
unfiltered_finished.sort(key=operator.attrgetter('score'), reverse=True)
# filtering
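# A finished hypothesis is kept only if:
#   * it does not mention the same table twice,
#   * (when from_cond) its join columns match exactly the foreign keys needed
#     to connect the mentioned tables in the schema's foreign-key graph,
#   * every table whose column is mentioned also appears in the FROM clause.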
filtered_finished = []
for hyp in unfiltered_finished:
mentioned_column_ids = set(hyp.column_history)
mentioned_key_column_ids = set(hyp.key_column_history)
mentioned_table_ids = set(hyp.table_history)
# reject FROM clauses that mention the same table more than once
if len(mentioned_table_ids) != len(hyp.table_history):
continue
# foreign keys must be used correctly
# NOTE: the new version does not predict conditions in the FROM clause anymore
if from_cond:
covered_tables = set()
must_include_key_columns = set()
candidate_table_ids = sorted(mentioned_table_ids)
start_table_id = candidate_table_ids[0]
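# Walk shortest paths in the foreign-key graph from the first mentioned table
# to each of the others; the foreign-key column pairs along those paths are
# the join columns this hypothesis is expected to have predicted.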
for table_id in candidate_table_ids[1:]:
if table_id in covered_tables:
continue
try:
path = nx.shortest_path(
orig_item.schema.foreign_key_graph, source=start_table_id, target=table_id)
except (nx.NetworkXNoPath, nx.NodeNotFound):
covered_tables.add(table_id)
continue
for source_table_id, target_table_id in zip(path, path[1:]):
if target_table_id in covered_tables:
continue
if target_table_id not in mentioned_table_ids:
continue
col1, col2 = orig_item.schema.foreign_key_graph[source_table_id][target_table_id]['columns']
must_include_key_columns.add(col1)
must_include_key_columns.add(col2)
if must_include_key_columns != mentioned_key_column_ids:
continue
# tables whose columns are mentioned must also appear in the FROM clause
must_table_ids = set()
for col in mentioned_column_ids:
tab_ = orig_item.schema.columns[col].table
if tab_ is not None:
must_table_ids.add(tab_.id)
if not must_table_ids.issubset(mentioned_table_ids):
continue
filtered_finished.append(hyp)
filtered_finished.sort(key=operator.attrgetter('score'), reverse=True)
# filtered.sort(key=lambda x: x.score / len(x.choice_history), reverse=True)
prefixes_unfinished.sort(key=operator.attrgetter('score'), reverse=True)
# new_prefixes.sort(key=lambda x: x.score / len(x.choice_history), reverse=True)
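# Split a single budget of beam_size between unfinished prefixes and newly
# filtered finished hypotheses (see merge_beams for the exact policy).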
prefixes_, filtered_ = merge_beams(prefixes_unfinished, filtered_finished, beam_size)
if filtered_:
cached_finished_seqs = cached_finished_seqs + filtered_
cached_finished_seqs.sort(key=operator.attrgetter('score'), reverse=True)
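# Keep searching from the surviving prefixes, resetting their per-clause
# histories so the heuristics only ever inspect one FROM clause at a time;
# otherwise return the best filtered sequences, falling back to unfiltered
# ones if nothing passed the filters.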
if prefixes_ and len(prefixes_[0].choice_history) < 200:
beam_prefix = prefixes_
for hyp in beam_prefix:
hyp.table_history = []
hyp.column_history = []
hyp.key_column_history = []
elif cached_finished_seqs:
return cached_finished_seqs[:beam_size]
else:
return unfiltered_finished[:beam_size]
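
# Minimal usage sketch (hypothetical variable names; finalize() on the
# inference state is assumed to behave as in the repo's inference script):
#
#   beams = beam_search_with_heuristics(model, orig_item, preproc_item,
#                                       beam_size=5, max_steps=1000)
#   for hyp in beams:
#       model_output, inferred_code = hyp.inference_state.finalize()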