in Models/exprsynth/nagdecoder.py [0:0]
def __load_expansiongraph_training_data_from_sample(
        hyperparameters: Dict[str, Any], metadata: Dict[str, Any],
        raw_sample: Dict[str, Any], prod_root_node: int, node_to_inherited_id: Dict[int, int],
        node_to_synthesised_id: Dict[int, int], last_used_node_id: Dict[str, int],
        result_holder: Dict[str, Any]) -> None:
    """Build the expansion-graph training data of one sample and store it in ``result_holder``.

    Starting at ``prod_root_node``, the production tree described by
    ``raw_sample['Productions']`` is walked depth-first, visiting children in
    left-to-right order. Every expanded node gets a fresh "synthesised" node id and
    every child gets a fresh "inherited" node id; ids are allocated as the current
    length of ``result_holder['eg_node_labels']``. While walking, the function records
    the chosen production per expanded node, the variable/literal choices, and a
    schedule of edges (one step per child, plus one step per expanded parent).

    Args:
        hyperparameters: Read keys are 'eg_use_literal_copying', 'eg_max_context_tokens',
            'eg_use_vars_for_production_choice' and
            'eg_update_last_variable_use_representation'; the dict is also passed to
            ``get_restricted_edge_types``.
        metadata: Read keys are 'eg_literal_vocabs', 'eg_token_vocab',
            'eg_production_vocab' and 'eg_edge_label_vocab'.
        raw_sample: Read keys are 'SymbolKinds', 'Productions', 'SymbolLabels' and
            'LastUseOfVariablesInScope'. The first three are indexed by ``str(node id)``.
        prod_root_node: Id (in the raw sample's numbering) of the node to expand first.
        node_to_inherited_id: Raw node id -> internal id of its inherited node.
            Must already contain ``prod_root_node``; extended in place.
        node_to_synthesised_id: Raw node id -> internal id of its synthesised node.
            Must already contain every node referenced by ``last_used_node_id``;
            extended in place.
        last_used_node_id: Maps ``LAST_USED_TOKEN_NAME`` and each variable in scope to
            the raw id of the node that last used it. Updated in place during the walk.
        result_holder: Output dict. ``result_holder['eg_node_labels']`` must already be
            a list (it is appended to), and 'context_nonkeyword_tokens' is read when
            literal copying is enabled.

    Side effects on ``result_holder`` (keys written):
        'eg_literal_choice_normalizer_maps' (only with 'eg_use_literal_copying'),
        'eg_production_node_id_to_var_last_use_node_ids' (only with
        'eg_use_vars_for_production_choice'), 'eg_node_id_to_prod_id',
        'eg_node_id_to_varchoice', 'eg_node_id_to_literal_choice', 'eg_schedule' and
        'eg_edge_label_ids'.

    Raises:
        MissingProductionException: If an expanded node uses a right-hand side that is
            not in ``metadata['eg_production_vocab']`` for the node's kind.
    """
    # Shortcuts and temp data we use during construction:
    symbol_to_kind = raw_sample['SymbolKinds']  # type: Dict[str, str]
    symbol_to_prod = raw_sample['Productions']  # type: Dict[str, List[int]]
    symbol_to_label = raw_sample['SymbolLabels']  # type: Dict[str, str]
    # Sorted so that the index of a variable in this list is deterministic:
    variables_in_scope = list(sorted(raw_sample['LastUseOfVariablesInScope'].keys()))  # type: List[str]

    # These are the things we'll use in the end:
    eg_node_id_to_prod_id = []  # type: List[Tuple[int, int]]  # Pairs of (node id, chosen production id)
    eg_node_id_to_varchoice = []  # type: List[Tuple[int, List[int], int]]  # Triples of (node id, [id of last var use], index of correct var)
    eg_node_id_to_literal_choice = defaultdict(list)  # type: Dict[str, List[Tuple[int, int]]]  # Maps literal kind to pairs of (node id, chosen literal id)
    eg_prod_node_to_var_last_uses = {}  # type: Dict[int, np.ndarray]  # Dict from production node id to [id of last var use]
    eg_schedule = []  # type: List[Dict[str, List[Tuple[int, int, Optional[int]]]]]  # edge type name to edges of that type in each expansion step.

    # We will use these to pick literals (literal kind -> (literal token -> choice id)):
    eg_literal_tok_to_idx = {}
    if hyperparameters['eg_use_literal_copying']:
        eg_literal_choice_normalizer_maps = {}
        # For each literal kind, we compute a "normalizer map", which we use to identify
        # choices that correspond to the same literal (e.g., when using a literal several
        # times in context)
        for literal_kind in LITERAL_NONTERMINALS:
            # Collect all choices (vocab + things we can copy from):
            literal_vocab = metadata['eg_literal_vocabs'][literal_kind]
            literal_choices = \
                literal_vocab.id_to_token \
                + result_holder['context_nonkeyword_tokens'][:hyperparameters['eg_max_context_tokens']]
            first_tok_occurrences = {}
            # The map always has room for the maximal number of context tokens, even if
            # this sample has fewer; unused trailing entries keep their identity value.
            num_choices = hyperparameters['eg_max_context_tokens'] + len(literal_vocab)
            # NOTE(review): int16 caps choice ids at 32767 - confirm vocab + context size stays below that.
            normalizer_map = np.arange(num_choices, dtype=np.int16)
            for (token_idx, token) in enumerate(literal_choices):
                first_occ = first_tok_occurrences.get(token)
                if first_occ is not None:
                    # Repeated token: redirect this choice to its first occurrence.
                    normalizer_map[token_idx] = first_occ
                else:
                    first_tok_occurrences[token] = token_idx
            eg_literal_tok_to_idx[literal_kind] = first_tok_occurrences
            eg_literal_choice_normalizer_maps[literal_kind] = normalizer_map
        result_holder['eg_literal_choice_normalizer_maps'] = eg_literal_choice_normalizer_maps
    else:
        # Without copying, the only choices are the entries of the literal vocabulary:
        for literal_kind in LITERAL_NONTERMINALS:
            eg_literal_tok_to_idx[literal_kind] = metadata['eg_literal_vocabs'][literal_kind].token_to_id

    # Map prods onto internal numbering and compute propagation schedule:
    def declare_new_node(sym_exp_node_id: int, is_synthesised: bool) -> int:
        """Allocate a new internal node for raw node ``sym_exp_node_id`` and return its id.

        The new id is the current length of ``result_holder['eg_node_labels']``, to which
        the node's label id is appended. The node label is the symbol's label if it has a
        non-empty one, otherwise its kind. The raw id is registered in
        ``node_to_synthesised_id`` or ``node_to_inherited_id`` depending on ``is_synthesised``.
        """
        new_node_id = len(result_holder['eg_node_labels'])
        node_label = symbol_to_label.get(str(sym_exp_node_id)) or symbol_to_kind[str(sym_exp_node_id)]
        result_holder['eg_node_labels'].append(metadata['eg_token_vocab'].get_id_or_unk(node_label))
        if is_synthesised:
            node_to_synthesised_id[sym_exp_node_id] = new_node_id
        else:
            node_to_inherited_id[sym_exp_node_id] = new_node_id
        return new_node_id

    def expand_node(node_id: int) -> None:
        """Recursively expand raw node ``node_id``, recording choices and schedule steps.

        Expects ``node_to_inherited_id[node_id]`` to be set already (by the caller for the
        root, by the parent's loop for every other node). Appends one schedule step per
        child (edges into the child's inherited node) and, after all children, one step
        with the edges into this node's synthesised node.
        """
        rhs_node_ids = symbol_to_prod.get(str(node_id))
        if rhs_node_ids is None:
            # In the case that we have no children, the downwards and upwards version of our node are the same:
            node_to_synthesised_id[node_id] = node_to_inherited_id[node_id]
            return

        declare_new_node(node_id, is_synthesised=True)

        # Figure out which production id this one is and store it away:
        rhs = raw_rhs_to_tuple(symbol_to_kind, symbol_to_label, rhs_node_ids)
        known_lhs_productions = metadata['eg_production_vocab'][symbol_to_kind[str(node_id)]]
        if rhs not in known_lhs_productions:
            raise MissingProductionException("%s -> %s" % (symbol_to_kind[str(node_id)], rhs))
        prod_id = known_lhs_productions[rhs]
        eg_node_id_to_prod_id.append((node_to_inherited_id[node_id], prod_id))
        if hyperparameters['eg_use_vars_for_production_choice']:
            # Snapshot of the last-use nodes of all variables at the time of this production choice.
            # NOTE(review): int16 caps node ids at 32767 - confirm graphs stay below that.
            eg_prod_node_to_var_last_uses[node_to_inherited_id[node_id]] = \
                np.array([node_to_synthesised_id[last_used_node_id[varchoice]] for varchoice in variables_in_scope], dtype=np.int16)
        # print("Expanding %i using rule %i %s -> %s" % (node_to_inherited_id[node_id], prod_id,
        #                                                symbol_to_label.get(str(node_id)) or symbol_to_kind[str(node_id)],
        #                                                tuple([symbol_to_kind[str(rhs_node_id)] for rhs_node_id in rhs_node_ids])))

        # Visit all children, in left-to-right order, descending into them if needed
        last_sibling = None
        # Edges into this node's synthesised node; only scheduled once all children are done.
        parent_inwards_edges = defaultdict(list)  # type: Dict[str, List[Tuple[int, int, Optional[int]]]]
        parent_inwards_edges['InheritedToSynthesised'].append((node_to_inherited_id[node_id],
                                                               node_to_synthesised_id[node_id],
                                                               None))
        for (rhs_symbol_idx, child_id) in enumerate(rhs_node_ids):
            child_inherited_id = declare_new_node(child_id, is_synthesised=False)
            child_inwards_edges = defaultdict(list)  # type: Dict[str, List[Tuple[int, int, Optional[int]]]]
            # 'Child' is the only edge type created here that carries a label id
            # (looked up by production id and position in the right-hand side):
            child_edge_label_id = metadata['eg_edge_label_vocab'][(prod_id, rhs_symbol_idx)]
            child_inwards_edges['Child'].append((node_to_inherited_id[node_id], child_inherited_id, child_edge_label_id))

            # Connection from left sibling, and prepare to be connected to the right sibling:
            if last_sibling is not None:
                child_inwards_edges['NextSibling'].append((node_to_synthesised_id[last_sibling], child_inherited_id, None))
            last_sibling = child_id

            # Connection from the last generated leaf ("next action" in "A syntactic neural model for general-purpose code generation", Yin & Neubig '17):
            child_inwards_edges['NextSubtree'].append((node_to_synthesised_id[last_used_node_id[LAST_USED_TOKEN_NAME]],
                                                       child_inherited_id,
                                                       None))

            # Check if we are terminal (token) node and add appropriate edges if that's the case:
            if str(child_id) not in symbol_to_prod:
                child_inwards_edges['NextToken'].append((node_to_synthesised_id[last_used_node_id[LAST_USED_TOKEN_NAME]],
                                                         child_inherited_id,
                                                         None))
                # This child is now the most recently generated token:
                last_used_node_id[LAST_USED_TOKEN_NAME] = child_id

            # If we are a variable or literal, we also need to store information to train to make the right choice:
            child_kind = symbol_to_kind[str(child_id)]
            if child_kind == VARIABLE_NONTERMINAL:
                var_name = symbol_to_label[str(child_id)]
                # print("  Chose variable %s" % var_name)
                last_var_use_id = last_used_node_id[var_name]
                cur_variable_to_last_use_ids = [node_to_synthesised_id[last_used_node_id[varchoice]] for varchoice in variables_in_scope]
                varchoice_id = variables_in_scope.index(var_name)
                # Note: the choice is recorded under the inherited id of the node being expanded (node_id), not of the child.
                eg_node_id_to_varchoice.append((node_to_inherited_id[node_id], cur_variable_to_last_use_ids, varchoice_id))
                child_inwards_edges['NextUse'].append((node_to_synthesised_id[last_var_use_id], child_inherited_id, None))
                if hyperparameters['eg_update_last_variable_use_representation']:
                    last_used_node_id[var_name] = child_id
            elif child_kind in LITERAL_NONTERMINALS:
                literal = symbol_to_label[str(child_id)]
                # print("  Chose literal %s" % literal)
                literal_id = eg_literal_tok_to_idx[child_kind].get(literal)
                # In the case that a literal is not in the vocab and not in the context, the above will return None,
                # so map that explicitly to the id for UNK:
                if literal_id is None:
                    literal_id = metadata['eg_literal_vocabs'][child_kind].get_id_or_unk(literal)
                # As for variables, recorded under the inherited id of the node being expanded (node_id).
                eg_node_id_to_literal_choice[child_kind].append((node_to_inherited_id[node_id], literal_id))

            # Store the edges leading to new node, recurse into it, and mark its upwards connection for later:
            eg_schedule.append(child_inwards_edges)
            expand_node(child_id)
            # Must come after the recursion: only then is node_to_synthesised_id[child_id] set.
            parent_inwards_edges['Parent'].append((node_to_synthesised_id[child_id], node_to_synthesised_id[node_id], None))
        eg_schedule.append(parent_inwards_edges)

    expand_node(prod_root_node)

    # Both results are used below as mappings from edge type name to an index.
    expansion_labeled_edge_types, expansion_unlabeled_edge_types = get_restricted_edge_types(hyperparameters)

    def split_schedule_step(step: Dict[str, List[Tuple[int, int, Optional[int]]]]) -> List[List[Tuple[int, int]]]:
        """Turn one schedule step into a list of (source, target) edge lists, one per edge type.

        Labeled edge types occupy the first positions (at their own index); unlabeled edge
        types follow, offset by the number of labeled types. Edge types that appear in
        neither mapping are dropped; types without edges in this step yield empty lists.
        """
        total_edge_types = len(expansion_labeled_edge_types) + len(expansion_unlabeled_edge_types)
        step_by_edge = [[] for _ in range(total_edge_types)]  # type: List[List[Tuple[int, int]]]
        for (label, edges) in step.items():
            edges = [(v, w) for (v, w, _) in edges]  # Strip off (optional) label:
            if label in expansion_labeled_edge_types:
                step_by_edge[expansion_labeled_edge_types[label]] = edges
            elif label in expansion_unlabeled_edge_types:
                step_by_edge[len(expansion_labeled_edge_types) + expansion_unlabeled_edge_types[label]] = edges
        return step_by_edge

    def edge_labels_from_schedule_step(step: Dict[str, List[Tuple[int, int, Optional[int]]]]) -> List[List[int]]:
        """Extract the edge label ids of one schedule step, one list per labeled edge type.

        The result is indexed like the labeled part of ``split_schedule_step``'s output, and
        each list is in the same order as the corresponding edge list.
        """
        labels_by_edge = [[] for _ in range(len(expansion_labeled_edge_types))]  # type: List[List[int]]
        for (label, edges) in step.items():
            if label in expansion_labeled_edge_types:
                label_ids = [l for (_, _, l) in edges]  # Keep only edge label
                labels_by_edge[expansion_labeled_edge_types[label]] = label_ids
        return labels_by_edge

    # Disabled debugging aid: prints the schedule and checks that every edge source was
    # initialised in an earlier step and every edge target is new.
    # print("Schedule:")
    # initialised_nodes = set()
    # initialised_nodes = initialised_nodes | result_holder['eg_node_id_to_cg_node_id'].keys()
    # for step_id, expansion_step in enumerate(eg_schedule):
    #     print(" Step %i" % step_id)
    #     initialised_this_step = set()
    #     for edge_type in EXPANSION_UNLABELED_EDGE_TYPE_NAMES + EXPANSION_LABELED_EDGE_TYPE_NAMES:
    #         for (v, w, _) in expansion_step[edge_type]:
    #             assert v in initialised_nodes
    #             assert w not in initialised_nodes
    #             initialised_this_step.add(w)
    #     for newly_computed_node in initialised_this_step:
    #         node_label_id = result_holder['eg_node_labels'][newly_computed_node]
    #         print("   Node Label for %i: %i (reversed %s)"
    #               % (newly_computed_node, node_label_id, metadata['eg_token_vocab'].id_to_token[node_label_id]))
    #     for edge_type in EXPANSION_UNLABELED_EDGE_TYPE_NAMES + EXPANSION_LABELED_EDGE_TYPE_NAMES:
    #         edges = expansion_step[edge_type]
    #         if len(edges) > 0:
    #             initialised_nodes = initialised_nodes | initialised_this_step
    #             print("   %s edges: [%s]" % (edge_type,
    #                                          ", ".join("(%s -[%s]> %s)" % (v, l, w) for (v, w, l) in edges)))
    # print("Variable choices:\n %s" % (str(eg_node_id_to_varchoice)))
    # print("Literal choices: \n %s" % (str(eg_node_id_to_literal_choice)))

    # Publish everything we collected:
    if hyperparameters['eg_use_vars_for_production_choice']:
        result_holder['eg_production_node_id_to_var_last_use_node_ids'] = eg_prod_node_to_var_last_uses
    # NOTE(review): if no node was expanded this array has shape (0,), not (0, 2) - confirm consumers cope.
    result_holder['eg_node_id_to_prod_id'] = np.array(eg_node_id_to_prod_id, dtype=np.int16)
    result_holder['eg_node_id_to_varchoice'] = eg_node_id_to_varchoice
    result_holder['eg_node_id_to_literal_choice'] = {}
    for literal_kind in LITERAL_NONTERMINALS:
        # .get() does not trigger the defaultdict factory, so kinds without any choice yield None:
        literal_choice_data = eg_node_id_to_literal_choice.get(literal_kind)
        if literal_choice_data is None:
            # Keep the two-column layout even when there are no choices of this kind.
            literal_choice_data = np.empty(shape=[0, 2], dtype=np.uint16)
        else:
            literal_choice_data = np.array(literal_choice_data, dtype=np.uint16)
        result_holder['eg_node_id_to_literal_choice'][literal_kind] = literal_choice_data
    # Both lists below have one entry per schedule step, in the same order:
    result_holder['eg_schedule'] = [split_schedule_step(step) for step in eg_schedule]
    result_holder['eg_edge_label_ids'] = [edge_labels_from_schedule_step(step) for step in eg_schedule]