# Models/exprsynth/nagdecoder.py
def generate_suggestions_for_one_sample(self,
test_sample: Dict[str, Any],
initial_eg_node_representations: tf.Tensor,
beam_size: int=3,
max_decoding_steps: int=100,
context_tokens: Optional[List[str]]=None,
context_token_representations: Optional[tf.Tensor]=None,
context_token_mask: Optional[np.ndarray]=None,
) -> ModelTestResult:
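        # Invert the nested production vocabulary (nonterminal -> expansion -> id) into a
        # flat id -> (nonterminal, expansion) lookup used when decoding sampled production ids.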
production_id_to_production = {} # type: Dict[int, Tuple[str, Iterable[str]]]
for (nonterminal, nonterminal_rules) in self.metadata['eg_production_vocab'].items():
for (expansion, prod_id) in nonterminal_rules.items():
production_id_to_production[prod_id] = (nonterminal, expansion)
max_used_eg_node_id = initial_eg_node_representations.shape[0]
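        # Every declared grammar node is materialised as a pair of expansion-graph nodes:
        # one carries the inherited attributes (information flowing down from the parent)
        # and one the synthesised attributes (information flowing up from the completed
        # subtree). Hence the fresh-id counter advances by two per declared node.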
def declare_new_node(expansion_info: ExpansionInformation, parent_node: int, node_type: str) -> int:
nonlocal max_used_eg_node_id
new_node_id = max_used_eg_node_id
new_synthesised_node_id = max_used_eg_node_id + 1
max_used_eg_node_id += 2
expansion_info.node_to_parent[new_node_id] = parent_node
expansion_info.node_to_children[parent_node].append(new_node_id)
expansion_info.node_to_type[new_node_id] = node_type
expansion_info.node_to_label[new_node_id] = node_type
expansion_info.node_to_label[new_synthesised_node_id] = node_type
expansion_info.node_to_synthesised_attr_node[new_node_id] = new_synthesised_node_id
expansion_info.node_to_inherited_attr_node[new_synthesised_node_id] = new_node_id
return new_node_id
def get_node_attributes(expansion_info: ExpansionInformation, node_id: int) -> tf.Tensor:
"""
Return attributes associated with node, either from cache in expansion information or by
calling the model to compute a representation according to the edge information in
the expansion_info.
"""
node_attributes = expansion_info.node_to_representation.get(node_id)
if node_attributes is None:
node_label = expansion_info.node_to_label[node_id]
if node_label == VARIABLE_NONTERMINAL:
node_label = ROOT_NONTERMINAL
if node_id not in expansion_info.node_to_unlabeled_incoming_edges:
self.compute_incoming_edges(self.metadata['eg_production_vocab'].keys(), expansion_info, node_id)
msg_prop_data = {self.placeholders['eg_msg_target_label_id']: self.metadata['eg_token_vocab'].get_id_or_unk(node_label)}
for labeled_edge_typ in self.__expansion_labeled_edge_types.keys():
source_node_ids = [v for (v, _) in expansion_info.node_to_labeled_incoming_edges[node_id][labeled_edge_typ]]
edge_labels = [l for (_, l) in expansion_info.node_to_labeled_incoming_edges[node_id][labeled_edge_typ]]
if len(source_node_ids) == 0:
sender_repr = np.empty(shape=[0, self.hyperparameters['eg_hidden_size']])
else:
sender_repr = [get_node_attributes(expansion_info, source_node_id) for source_node_id in source_node_ids]
labeled_edge_typ_idx = self.__expansion_labeled_edge_types[labeled_edge_typ]
msg_prop_data[self.placeholders['eg_msg_source_representations'][labeled_edge_typ_idx]] = sender_repr
msg_prop_data[self.placeholders['eg_edge_label_ids'][0][labeled_edge_typ_idx]] = edge_labels
for unlabeled_edge_typ in self.__expansion_unlabeled_edge_types.keys():
source_node_ids = expansion_info.node_to_unlabeled_incoming_edges[node_id][unlabeled_edge_typ]
if len(source_node_ids) == 0:
sender_repr = np.empty(shape=[0, self.hyperparameters['eg_hidden_size']])
else:
sender_repr = [get_node_attributes(expansion_info, source_node_id) for source_node_id in source_node_ids]
shifted_edge_type_id = len(self.__expansion_labeled_edge_types) + self.__expansion_unlabeled_edge_types[unlabeled_edge_typ]
msg_prop_data[self.placeholders['eg_msg_source_representations'][shifted_edge_type_id]] = sender_repr
# print("Computing attributes for %i (label %s) with following edges:" % (node_id, node_label))
# for labeled_edge_type in EXPANSION_LABELED_EDGE_TYPE_NAMES + EXPANSION_UNLABELED_EDGE_TYPE_NAMES:
# edges = expansion_info.node_to_labeled_incoming_edges[node_id][labeled_edge_type]
# if len(edges) > 0:
# print(" %s edges: [%s]" % (labeled_edge_type,
# ", ".join("(%s, %s)" % (w, node_id) for w in edges)))
node_attributes = self.sess.run(self.ops['eg_step_propagation_result'], feed_dict=msg_prop_data)
expansion_info.node_to_representation[node_id] = node_attributes
return node_attributes
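        # Query the model for a distribution over grammar productions at the given node
        # and return the top choices as (rhs, production id, probability) triples.
        # pick_indices_from_probs (defined elsewhere in this module) is assumed to return
        # up to beam_size indices, favouring high-probability entries.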
def sample_productions(expansion_info: ExpansionInformation, node_to_expand: int) -> List[Tuple[Iterable[str], int, float]]:
prod_query_data = {}
write_to_minibatch(prod_query_data,
self.placeholders['eg_production_node_representation'],
get_node_attributes(expansion_info, node_to_expand))
if self.hyperparameters['eg_use_vars_for_production_choice']:
vars_in_scope = list(expansion_info.variable_to_last_use_id.keys())
vars_in_scope.remove(LAST_USED_TOKEN_NAME)
vars_in_scope_representations = [get_node_attributes(expansion_info, expansion_info.variable_to_last_use_id[var])
for var in vars_in_scope]
write_to_minibatch(prod_query_data,
self.placeholders['eg_production_var_representations'],
vars_in_scope_representations)
if self.hyperparameters['eg_use_literal_copying'] or self.hyperparameters['eg_use_context_attention']:
write_to_minibatch(prod_query_data,
self.placeholders['context_token_representations'],
expansion_info.context_token_representations)
write_to_minibatch(prod_query_data,
self.placeholders['context_token_mask'],
expansion_info.context_token_mask)
production_probs = self.sess.run(self.ops['eg_production_choice_probs'], feed_dict=prod_query_data)
result = []
# print("### Prod probs: %s" % (str(production_probs),))
for picked_production_index in pick_indices_from_probs(production_probs, beam_size):
prod_lhs, prod_rhs = production_id_to_production[picked_production_index]
                # Skip productions whose left-hand side does not match the node's type.
                # TODO: This should be ensured by appropriate masking in the model
                if prod_lhs == expansion_info.node_to_type[node_to_expand]:
                    result.append((prod_rhs, picked_production_index, production_probs[picked_production_index]))
return result
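        # Score all variables currently in scope, each represented by its last-use node,
        # against the parent node's representation and return the top (variable name,
        # probability) pairs.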
def sample_variable(expansion_info: ExpansionInformation, node_id: int) -> List[Tuple[str, float]]:
vars_in_scope = list(expansion_info.variable_to_last_use_id.keys())
vars_in_scope.remove(LAST_USED_TOKEN_NAME)
vars_in_scope_representations = [get_node_attributes(expansion_info, expansion_info.variable_to_last_use_id[var])
for var in vars_in_scope]
var_query_data = {self.placeholders['eg_num_variable_choices']: len(vars_in_scope)}
write_to_minibatch(var_query_data, self.placeholders['eg_varproduction_options_representations'], vars_in_scope_representations)
            # We choose the variable name based on the representation of the /parent/ node:
parent_node = expansion_info.node_to_parent[node_id]
write_to_minibatch(var_query_data, self.placeholders['eg_varproduction_node_representation'], get_node_attributes(expansion_info, parent_node))
var_probs = self.sess.run(self.ops['eg_varproduction_choice_probs'], feed_dict=var_query_data)
result = []
# print("### Var probs: %s" % (str(var_probs),))
for picked_var_index in pick_indices_from_probs(var_probs, beam_size):
result.append((vars_in_scope[picked_var_index], var_probs[picked_var_index]))
return result
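        # Sample a literal of the kind demanded by the node's type. When literal copying
        # is enabled, choice indices beyond the literal vocabulary size refer to tokens
        # copied verbatim from the context.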
def sample_literal(expansion_info: ExpansionInformation, node_id: int) -> List[Tuple[str, float]]:
literal_kind_to_sample = expansion_info.node_to_type[node_id]
lit_query_data = {}
            # We choose the literal based on the representation of the /parent/ node:
parent_node = expansion_info.node_to_parent[node_id]
write_to_minibatch(lit_query_data, self.placeholders['eg_litproduction_node_representation'], get_node_attributes(expansion_info, parent_node))
if self.hyperparameters["eg_use_literal_copying"]:
write_to_minibatch(lit_query_data,
self.placeholders['context_token_representations'],
expansion_info.context_token_representations)
write_to_minibatch(lit_query_data,
self.placeholders['context_token_mask'],
expansion_info.context_token_mask)
write_to_minibatch(lit_query_data,
self.placeholders['eg_litproduction_choice_normalizer'],
expansion_info.literal_production_choice_normalizer[literal_kind_to_sample])
lit_probs = self.sess.run(self.ops['eg_litproduction_choice_probs'][literal_kind_to_sample],
feed_dict=lit_query_data)
result = []
# print("### Var probs: %s" % (str(lit_probs),))
literal_vocab = self.metadata["eg_literal_vocabs"][literal_kind_to_sample]
literal_vocab_size = len(literal_vocab)
for picked_lit_index in pick_indices_from_probs(lit_probs, beam_size):
if picked_lit_index < literal_vocab_size:
result.append((literal_vocab.id_to_token[picked_lit_index], lit_probs[picked_lit_index]))
else:
result.append((expansion_info.context_tokens[picked_lit_index - literal_vocab_size], lit_probs[picked_lit_index]))
return result
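        # One decoding step: pop the next node from the queue and, depending on its type,
        # apply a grammar production, bind a variable, produce a literal, or close a
        # terminal leaf. Returns up to beam_size cloned expansion states, or none if the
        # step budget is exhausted.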
def expand_node(expansion_info: ExpansionInformation) -> List[ExpansionInformation]:
if len(expansion_info.nodes_to_expand) == 0:
return [expansion_info]
if expansion_info.num_expansions > max_decoding_steps:
return []
node_to_expand = expansion_info.nodes_to_expand.popleft()
type_to_expand = expansion_info.node_to_type[node_to_expand]
expansions = []
if type_to_expand in self.metadata['eg_production_vocab']:
# Case production from grammar
for (prod_rhs, prod_id, prod_probability) in sample_productions(expansion_info, node_to_expand):
picked_rhs_expansion_info = clone_expansion_info(expansion_info, increment_expansion_counter=True)
picked_rhs_expansion_info.node_to_prod_id[node_to_expand] = prod_id
picked_rhs_expansion_info.expansion_logprob[0] = picked_rhs_expansion_info.expansion_logprob[0] + np.log(prod_probability)
# print("Expanding %i using rule %s -> %s with prob %.3f in %s (tree prob %.3f)."
# % (node_to_expand, type_to_expand, prod_rhs, prod_probability,
# " ".join(get_tokens_from_expansion(expansion_info, root_node)),
# np.exp(expansion_info.expansion_logprob[0])))
# Declare all children
for child_node_type in prod_rhs:
child_node_id = declare_new_node(picked_rhs_expansion_info, node_to_expand, child_node_type)
# print(" Child %i (type %s)" % (child_node_id, child_node_type))
                    # Queue the children for expansion. We expand depth-first, so they go to the
                    # front of the queue; extendleft reverses its argument, so reverse the
                    # children first to preserve left-to-right expansion order:
picked_rhs_expansion_info.nodes_to_expand.extendleft(reversed(picked_rhs_expansion_info.node_to_children[node_to_expand]))
expansions.append(picked_rhs_expansion_info)
elif type_to_expand == VARIABLE_NONTERMINAL:
# Case choose variable name.
if len(expansion_info.variable_to_last_use_id.keys()) > 1: # Only continue if at least one var is in scope (not just LAST_USED_TOKEN_NAME)
for (child_label, var_probability) in sample_variable(expansion_info, node_to_expand):
# print("Expanding %i by using variable %s with prob %.3f in %s (tree prob %.3f)."
# % (node_to_expand, child_label, var_probability,
# " ".join(get_tokens_from_expansion(expansion_info, root_node)),
# np.exp(expansion_info.expansion_logprob[0])))
child_expansion_info = clone_expansion_info(expansion_info)
                        child_expansion_info.node_to_synthesised_attr_node[node_to_expand] = node_to_expand  # synthesised and inherited are the same for leaves
child_expansion_info.node_to_label[node_to_expand] = child_label
self.compute_incoming_edges(self.metadata['eg_production_vocab'].keys(), child_expansion_info, node_to_expand) # This needs to be done now before we update the variable-to-last-use info
child_expansion_info.expansion_logprob[0] = child_expansion_info.expansion_logprob[0] + np.log(var_probability)
if self.hyperparameters['eg_update_last_variable_use_representation']:
child_expansion_info.variable_to_last_use_id[child_label] = node_to_expand
child_expansion_info.variable_to_last_use_id[LAST_USED_TOKEN_NAME] = node_to_expand
expansions.append(child_expansion_info)
elif type_to_expand in LITERAL_NONTERMINALS:
for (picked_literal, literal_probability) in sample_literal(expansion_info, node_to_expand):
# print("Expanding %i by using literal %s with prob %.3f in %s (tree prob %.3f)."
# % (node_to_expand, picked_literal, literal_probability,
# " ".join(get_tokens_from_expansion(expansion_info, root_node)),
# np.exp(expansion_info.expansion_logprob[0])))
picked_literal_expansion_info = clone_expansion_info(expansion_info)
                    picked_literal_expansion_info.node_to_synthesised_attr_node[node_to_expand] = node_to_expand  # synthesised and inherited are the same for leaves
picked_literal_expansion_info.node_to_label[node_to_expand] = picked_literal
self.compute_incoming_edges(self.metadata['eg_production_vocab'].keys(), picked_literal_expansion_info,
node_to_expand)
picked_literal_expansion_info.expansion_logprob[0] = picked_literal_expansion_info.expansion_logprob[0] + np.log(literal_probability)
picked_literal_expansion_info.variable_to_last_use_id[LAST_USED_TOKEN_NAME] = node_to_expand
expansions.append(picked_literal_expansion_info)
else:
# Case node is a terminal: Do nothing
# print("Handling leaf node %i (label %s)" % (node_to_expand, expansion_info.node_to_label[node_to_expand]))
                expansion_info.node_to_synthesised_attr_node[node_to_expand] = node_to_expand  # synthesised and inherited are the same for leaves
self.compute_incoming_edges(self.metadata['eg_production_vocab'].keys(), expansion_info, node_to_expand) # This needs to be done now before we update the variable-to-last-use info
expansion_info.variable_to_last_use_id[LAST_USED_TOKEN_NAME] = node_to_expand
expansions = [expansion_info]
return expansions
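        # With literal copying, the same token may appear several times among the choices
        # (in the literal vocabulary and at several context positions). The normalizer
        # maps each duplicate choice index to its first occurrence so the model can pool
        # the probability mass of identical tokens.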
if self.hyperparameters['eg_use_literal_copying']:
            assert context_tokens is not None, "Literal copying requires context_tokens to be provided"
            literal_production_choice_normalizer = {}
for literal_kind in LITERAL_NONTERMINALS:
# Collect all choices (vocab + things we can copy from):
literal_vocab = self.metadata['eg_literal_vocabs'][literal_kind]
literal_choices = literal_vocab.id_to_token + context_tokens
first_tok_occurrences = {}
num_choices = len(literal_vocab) + self.hyperparameters['eg_max_context_tokens']
normalizer_map = np.arange(num_choices, dtype=np.int16)
for (token_idx, token) in enumerate(literal_choices):
first_occ = first_tok_occurrences.get(token)
if first_occ is not None:
normalizer_map[token_idx] = first_occ
else:
first_tok_occurrences[token] = token_idx
literal_production_choice_normalizer[literal_kind] = normalizer_map
else:
literal_production_choice_normalizer = None
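        # Seed the expansion state with the root node and the last-use nodes of all
        # variables (plus the last-used-token marker), all starting from the node
        # representations computed by the encoder.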
root_node = test_sample['eg_root_node']
        initial_variable_to_last_use_id = dict(test_sample['eg_variable_eg_node_ids'])  # copy to avoid mutating the input sample
initial_variable_to_last_use_id[LAST_USED_TOKEN_NAME] = test_sample['eg_last_token_eg_node_id']
initial_node_to_representation = {node_id: initial_eg_node_representations[node_id]
for node_id in initial_variable_to_last_use_id.values()}
initial_node_to_representation[root_node] = initial_eg_node_representations[root_node]
initial_info = ExpansionInformation(node_to_type={root_node: ROOT_NONTERMINAL},
node_to_label={root_node: ROOT_NONTERMINAL},
node_to_prod_id={},
node_to_children=defaultdict(list),
node_to_parent={},
node_to_synthesised_attr_node={node_id: node_id for node_id in initial_node_to_representation.keys()},
node_to_inherited_attr_node={},
variable_to_last_use_id=initial_variable_to_last_use_id,
node_to_representation=initial_node_to_representation,
node_to_labeled_incoming_edges={root_node: defaultdict(list)},
node_to_unlabeled_incoming_edges={root_node: defaultdict(list)},
context_token_representations=context_token_representations,
context_token_mask=context_token_mask,
context_tokens=context_tokens,
literal_production_choice_normalizer=literal_production_choice_normalizer,
nodes_to_expand=deque([root_node]),
expansion_logprob=[0.0],
num_expansions=0)
beams = [initial_info]
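        # Standard beam search: expand each live beam by one step and keep the beam_size
        # highest-scoring states until no beam has nodes left to expand.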
while any(len(b.nodes_to_expand) > 0 for b in beams):
new_beams = [new_beam
for beam in beams
for new_beam in expand_node(beam)]
beams = sorted(new_beams, key=lambda b: -b.expansion_logprob[0])[:beam_size] # Pick top K beams
self.test_log("Groundtruth: %s" % (" ".join(test_sample['eg_tokens']),))
all_predictions = [] # type: List[Tuple[List[str], float]]
for (k, beam_info) in enumerate(beams):
kth_result = get_tokens_from_expansion(beam_info, root_node)
all_predictions.append((kth_result, np.exp(beam_info.expansion_logprob[0])))
self.test_log(" @%i Prob. %.3f: %s" % (k+1, np.exp(beam_info.expansion_logprob[0]), " ".join(kth_result)))
if len(beams) == 0:
self.test_log("No beams finished!")
return ModelTestResult(test_sample['eg_tokens'], all_predictions)
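    # A minimal usage sketch (hypothetical driver code; `model`, `sample`, and
    # `encode_sample` are assumptions, not part of this file):
    #
    #     node_reprs = encode_sample(model, sample)  # one representation per EG node
    #     test_result = model.generate_suggestions_for_one_sample(
    #         sample, node_reprs, beam_size=5, max_decoding_steps=100)
    #     # test_result pairs the ground-truth tokens with the ranked
    #     # (token sequence, probability) predictions.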