in rat-sql-gap/seq2struct/models/nl2code/decoder.py [0:0]
def compute_mle_loss(self, enc_input, example, desc_enc, debug=False):
    """Replay the gold tree of ``example`` through a training traversal and sum its losses.

    The gold AST ``example.tree`` is walked depth-first, left to right.
    Every decision the decoder would have to make is fed to a
    ``TrainTreeTraversal`` via ``traversal.step(...)``, in exactly the order
    the traversal expects. The traversal's current state is asserted before
    each step. The kinds of decision are:

    * list/tuple node   -> list-length rule ``(parent_field_type + '*', len)``
    * pointer field     -> pointer index (plus extra targets from ``pointer_maps``)
    * primitive field   -> the tokenized value followed by ``vocab.EOS``
    * constructor node  -> sum-type rule (if the parent is a sum type), then a
      children rule describing which fields are present

    Args:
        enc_input: Encoder input. Only ``len(enc_input['question'])`` and
            ``len(enc_input['columns'])`` are read, and only when
            ``self.sup_att == '1h'``.
        example: Object whose ``.tree`` attribute is the gold tree (nested
            dicts with a ``'_type'`` key, lists/tuples, ints and primitives).
        desc_enc: Encoder output passed to the traversal. Its
            ``pointer_maps`` mapping is consulted for pointer fields.
        debug: Forwarded to ``TrainTreeTraversal``. When true, the traversal
            history is also returned.

    Returns:
        ``torch.sum`` over dim 0 of the stacked ``traversal.loss`` entries.
        When ``debug`` is true, this is a tuple
        ``(loss, [attr.asdict(entry) for entry in traversal.history])``.

    Raises:
        AssertionError: If the traversal is not in the state expected for the
            current gold node (decoder and grammar out of sync).
        KeyError: If a rule derived from the gold tree is not in
            ``self.rules_index``.
    """
    traversal = TrainTreeTraversal(self, desc_enc, debug)
    traversal.step(None)

    # Used as a LIFO stack: children are pushed in reverse so that popping
    # from the end visits them left to right (pre-order DFS).
    queue = [
        TreeState(
            node=example.tree,
            parent_field_type=self.preproc.grammar.root_type,
        )
    ]
    while queue:
        item = queue.pop()
        node = item.node
        parent_field_type = item.parent_field_type

        if isinstance(node, (list, tuple)):
            # Sequence field: first choose the length, then decode each element.
            node_type = parent_field_type + '*'
            rule = (node_type, len(node))
            rule_idx = self.rules_index[rule]
            assert traversal.cur_item.state == TreeTraversal.State.LIST_LENGTH_APPLY
            traversal.step(rule_idx)

            if self.preproc.use_seq_elem_rules and parent_field_type in self.ast_wrapper.sum_types:
                parent_field_type += '_seq_elem'

            for elem in reversed(node):
                queue.append(
                    TreeState(
                        node=elem,
                        parent_field_type=parent_field_type,
                    ))
            continue

        if parent_field_type in self.preproc.grammar.pointers:
            assert isinstance(node, int)
            assert traversal.cur_item.state == TreeTraversal.State.POINTER_APPLY
            pointer_map = desc_enc.pointer_maps.get(parent_field_type)
            if pointer_map:
                values = pointer_map[node]
                if self.sup_att == '1h':
                    # The third step argument is the pointer index shifted by an
                    # offset. It appears to index a concatenated
                    # [question, columns, tables] memory, or [columns, tables]
                    # when attn_type == 'sep' -- TODO confirm against
                    # TrainTreeTraversal.
                    offset = 0 if self.attn_type == 'sep' else len(enc_input['question'])
                    # NOTE(review): a column pointer is recognised only by
                    # len(pointer_map) == number of columns. A table pointer map
                    # of the same length would be misclassified; keying on
                    # parent_field_type would be more robust.
                    if len(pointer_map) != len(enc_input['columns']):
                        offset += len(enc_input['columns'])
                    traversal.step(values[0], values[1:], node + offset)
                else:
                    traversal.step(values[0], values[1:])
            else:
                traversal.step(node)
            continue

        if parent_field_type in self.ast_wrapper.primitive_types:
            # identifier, int, string, bytes, object, singleton
            # - could be bytes, str, int, float, bool, NoneType
            # - terminal tokens vocabulary is created by turning everything into a string (with `str`)
            # - at decoding time, cast back to str/int/float/bool
            field_value_split = self.preproc.grammar.tokenize_field_value(node) + [vocab.EOS]
            for token in field_value_split:
                assert traversal.cur_item.state == TreeTraversal.State.GEN_TOKEN
                traversal.step(token)
            continue

        type_info = self.ast_wrapper.singular_types[node['_type']]
        if parent_field_type in self.preproc.sum_type_constructors:
            # ApplyRule, like expr -> Call
            rule = (parent_field_type, type_info.name)
            rule_idx = self.rules_index[rule]
            assert traversal.cur_item.state == TreeTraversal.State.SUM_TYPE_APPLY
            extra_rules = [
                self.rules_index[parent_field_type, extra_type]
                for extra_type in node.get('_extra_types', [])]
            traversal.step(rule_idx, extra_rules)

        if type_info.fields:
            # ApplyRule, like Call -> expr[func] expr*[args] keyword*[keywords]
            # Figure out which rule needs to be applied
            present = get_field_presence_info(self.ast_wrapper, node, type_info.fields)
            rule = (node['_type'], tuple(present))
            rule_idx = self.rules_index[rule]
            assert traversal.cur_item.state == TreeTraversal.State.CHILDREN_APPLY
            traversal.step(rule_idx)

        # reversed so that we perform a DFS in left-to-right order
        for field_info in reversed(type_info.fields):
            if field_info.name not in node:
                continue
            queue.append(
                TreeState(
                    node=node[field_info.name],
                    parent_field_type=field_info.type,
                ))

    loss = torch.sum(torch.stack(tuple(traversal.loss), dim=0), dim=0)
    if debug:
        return loss, [attr.asdict(entry) for entry in traversal.history]
    return loss