in rat-sql-gap/seq2struct/models/spider/spider_beam_search.py [0:0]
def _force_gold_choice(hyp, pending, node_type):
    """Restrict ``hyp.next_choices`` to the entry for the next gold pointer.

    Args:
        hyp: hypothesis whose ``next_choices`` is a sequence of entries whose
            first element is the pointer value being chosen.
        pending: non-empty list of gold values still to be emitted; the next
            one to emit sits at the tail and is popped on success.
        node_type: label of the pointer kind, used only in the error message.

    Raises:
        AssertionError: if no entry in ``hyp.next_choices`` selects the gold
            value.  Raised explicitly (rather than via ``assert``) so that
            running under ``python -O`` cannot silently fall back to the
            model's own, non-oracle choices.
    """
    gold = pending[-1]
    for candidate in hyp.next_choices:
        if candidate[0] == gold:
            hyp.next_choices = [candidate]
            pending.pop()
            return
    raise AssertionError(
        "gold {} {!r} is not among the decoder's next choices".format(node_type, gold))


def beam_search_with_oracle_column(model, orig_item, preproc_item, beam_size, max_steps, visualize_flag=False):
    """Greedy decoding in which column/table pointers are forced to gold values.

    Whenever the decoder is about to apply a pointer of node type ``"column"``
    or ``"table"`` and gold values of that type remain, the hypothesis'
    choices are replaced by the single choice that selects the next gold
    value.  Every other decoding step is scored and expanded as in an ordinary
    beam search.

    Args:
        model: provides ``begin_inference(orig_item, preproc_item)`` returning
            ``(inference_state, next_choices)`` and
            ``decoder.ast_wrapper.find_all_descendants_of_type``.
        orig_item: passed through unchanged to ``model.begin_inference``.
        preproc_item: passed to ``model.begin_inference``; the gold tree is
            read from ``preproc_item[1].tree``.
        beam_size: must be 1; the oracle rewrites a single hypothesis.
        max_steps: upper bound on the number of decoding steps.
        visualize_flag: when true, print the step index at every step.

    Returns:
        The finished hypotheses sorted by descending ``score``.  The list is
        empty if decoding does not finish within ``max_steps``.

    Raises:
        AssertionError: if ``beam_size != 1``, if the beam does not hold
            exactly one hypothesis at the start of a step, or if a gold
            column/table is not among the choices offered by the decoder.
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    beam = [Hypothesis(inference_state, next_choices)]
    finished = []
    assert beam_size == 1, "oracle decoding requires beam_size == 1"

    # Gold columns/tables mentioned in the target tree.  They are consumed in
    # the reverse of the order find_all_descendants_of_type yields them, so
    # each list is kept in yield order and popped from its tail, which is O(1)
    # instead of re-slicing a reversed copy on every pointer step.
    root_node = preproc_item[1].tree
    find_gold = model.decoder.ast_wrapper.find_all_descendants_of_type
    gold_remaining = (
        ("column", list(find_gold(root_node, "column"))),
        ("table", list(find_gold(root_node, "table"))),
    )

    for step in range(max_steps):
        if visualize_flag:
            print('step:')
            print(step)

        # Stop as soon as every beam slot holds a finished hypothesis.
        if len(finished) == beam_size:
            break

        # Hijack the next choice with the gold column/table where applicable.
        assert len(beam) == 1
        hyp = beam[0]
        cur_item = hyp.inference_state.cur_item
        if cur_item.state == TreeTraversal.State.POINTER_APPLY:
            for node_type, pending in gold_remaining:
                # Once the gold values of a type are exhausted the decoder's
                # own choices are left untouched.
                if cur_item.node_type == node_type and pending:
                    _force_gold_choice(hyp, pending, node_type)
                    break

        # Score every possible expansion of every hypothesis.
        candidates = []
        for hyp in beam:
            for choice, choice_score in hyp.next_choices:
                # Read the scalar once; presumably a tensor, for which each
                # .item() call is a separate conversion -- TODO confirm.
                score = choice_score.item()
                candidates.append((hyp, choice, score, hyp.score + score))

        # Keep the top K expansions by cumulative score.
        candidates.sort(key=operator.itemgetter(3), reverse=True)
        candidates = candidates[:beam_size - len(finished)]

        # Advance a cloned state for each surviving expansion.
        beam = []
        for hyp, choice, choice_score, cum_score in candidates:
            inference_state = hyp.inference_state.clone()
            next_choices = inference_state.step(choice)
            expanded = Hypothesis(
                inference_state,
                next_choices,
                cum_score,
                hyp.choice_history + [choice],
                hyp.score_history + [choice_score])
            # step() returning None marks the hypothesis as complete.
            if next_choices is None:
                finished.append(expanded)
            else:
                beam.append(expanded)

    finished.sort(key=operator.attrgetter('score'), reverse=True)
    return finished