in rat-sql-gap/seq2struct/models/spider/spider_beam_search.py [0:0]
def beam_search_with_oracle_sketch(model, orig_item, preproc_item, beam_size, max_steps, visualize_flag=False):
    """Force-decode the tree stored in ``preproc_item[1].tree`` through the model.

    Instead of searching, this walks the tree depth-first (left to right) and
    feeds every rule / pointer / token it encounters to the inference state
    obtained from ``model.begin_inference``. Each forced step is recorded in
    the hypothesis with a score of 0.

    NOTE(review): the tree is presumably the gold (oracle) decoder output for
    the example -- confirm against the preprocessing code.

    Args:
        model: object exposing ``begin_inference`` and a ``decoder`` with
            ``preproc``, ``rules_index`` and ``ast_wrapper`` attributes.
        orig_item: original example; only ``orig_item.code`` is read here, to
            check that the grammar can parse it.
        preproc_item: preprocessed example; ``preproc_item[1].tree`` is the
            tree that gets replayed.
        beam_size: unused; kept for signature compatibility.
        max_steps: unused; kept for signature compatibility.
        visualize_flag: unused; kept for signature compatibility.

    Returns:
        ``[]`` if the grammar cannot parse ``orig_item.code`` or if a required
        list-length rule is missing from ``rules_index``; otherwise a
        one-element list holding the final ``Hypothesis``.

    Raises:
        KeyError: if a sum-type or field-presence rule is not in
            ``rules_index`` (only list-length rules are checked up front).
        AssertionError: if the inference state is not in the state expected
            for the node being replayed.
    """
    decoder = model.decoder
    grammar = decoder.preproc.grammar
    ast_wrapper = decoder.ast_wrapper
    rules_index = decoder.rules_index

    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    hyp = Hypothesis(inference_state, next_choices)

    # Only the success of the parse matters; the parsed value itself is unused.
    if not grammar.parse(orig_item.code, "val"):
        return []

    # Used as a stack (pop from the end) so the traversal is depth-first.
    queue = [
        TreeState(
            node=preproc_item[1].tree,
            parent_field_type=grammar.root_type,
        )
    ]

    while queue:
        item = queue.pop()
        node = item.node
        parent_field_type = item.parent_field_type

        # Sequence field: emit the list-length rule, then schedule elements.
        if isinstance(node, (list, tuple)):
            rule = (parent_field_type + '*', len(node))
            if rule not in rules_index:
                return []
            rule_idx = rules_index[rule]
            assert inference_state.cur_item.state == TreeTraversal.State.LIST_LENGTH_APPLY
            inference_state.step(rule_idx)

            if decoder.preproc.use_seq_elem_rules and \
                    parent_field_type in ast_wrapper.sum_types:
                parent_field_type += '_seq_elem'

            # Pushed in reverse so that popping visits elements left to right.
            queue.extend(
                TreeState(node=elem, parent_field_type=parent_field_type)
                for elem in reversed(node))

            hyp = _extend_oracle_hypothesis(hyp, inference_state, [rule_idx])
            continue

        # Pointer field: the pointer value comes from the tree itself, not
        # from the model's scored choices.
        if parent_field_type in grammar.pointers:
            assert inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY
            assert isinstance(node, int)
            inference_state.step(node)
            hyp = _extend_oracle_hypothesis(hyp, inference_state, [node])
            continue

        # Primitive field: feed the value token by token, terminated by <EOS>.
        if parent_field_type in ast_wrapper.primitive_types:
            field_value_split = grammar.tokenize_field_value(node) + ['<EOS>']
            for token in field_value_split:
                inference_state.step(token)
            # One score entry per token keeps score_history aligned with
            # choice_history, as in every other branch.
            hyp = _extend_oracle_hypothesis(hyp, inference_state, field_value_split)
            continue

        type_info = ast_wrapper.singular_types[node['_type']]

        if parent_field_type in decoder.preproc.sum_type_constructors:
            # ApplyRule, like expr -> Call
            rule_idx = rules_index[parent_field_type, type_info.name]
            assert inference_state.cur_item.state == TreeTraversal.State.SUM_TYPE_APPLY
            extra_rules = [
                rules_index[parent_field_type, extra_type]
                for extra_type in node.get('_extra_types', [])]
            inference_state.step(rule_idx, extra_rules)
            hyp = _extend_oracle_hypothesis(hyp, inference_state, [rule_idx])

        if type_info.fields:
            # ApplyRule, like Call -> expr[func] expr*[args] keyword*[keywords]
            # The rule is keyed by which of the type's fields are present.
            present = get_field_presence_info(ast_wrapper, node, type_info.fields)
            rule_idx = rules_index[node['_type'], tuple(present)]
            inference_state.step(rule_idx)
            hyp = _extend_oracle_hypothesis(hyp, inference_state, [rule_idx])

        # Reversed so that we perform a DFS in left-to-right order.
        for field_info in reversed(type_info.fields):
            if field_info.name not in node:
                continue
            queue.append(
                TreeState(
                    node=node[field_info.name],
                    parent_field_type=field_info.type,
                ))

    return [hyp]


def _extend_oracle_hypothesis(hyp, inference_state, choices):
    """Return a new Hypothesis with ``choices`` appended as forced, zero-score steps."""
    choices = list(choices)
    return Hypothesis(
        inference_state,
        None,  # next_choices: not tracked while replaying a fixed tree
        0,     # overall score
        hyp.choice_history + choices,
        hyp.score_history + [0] * len(choices))