in rat-sql-gap/seq2struct/ast_util.py [0:0]
def verify_ast(self, node, expected_type=None, field_path=(), is_seq=False):
    # type: (ASTWrapper, Node, Optional[str], Tuple[str, ...], bool) -> bool
    # pylint: disable=too-many-branches
    '''Checks that `node` conforms to the current ASDL.

    Validates the node's `_type` against `expected_type` (when given), then
    checks every field declared for that type: primitive-typed values are
    passed to the matching checker in `self.primitive_type_checkers`, and all
    other values are verified recursively.

    Args:
        node: dict describing an AST node; must contain a '_type' key.
        expected_type: name of an ASDL type (a key of `self.types`) that
            `node` must satisfy, or None to skip that check.
        field_path: tuple of field names leading to `node`; used only in
            error messages.
        is_seq: True when `node` is an element of a sequential field; the
            expected sum type's `seq_fragment_types` (if that attribute
            exists) are then accepted as well.

    Returns:
        True when `node` and all of its descendants verify.

    Raises:
        ValueError: on a None or non-dict node, a type mismatch, an unknown
            node type, a sum type used as a node type, a missing required
            field, or a sequential field whose value is not a list/tuple.
        AssertionError: when a primitive-typed field value is rejected by
            its checker.
        KeyError: when `node` has no '_type' key, or `expected_type` is not
            a key of `self.types`.
    '''
    if node is None:
        raise ValueError('node is None. path: {}'.format(field_path))
    if not isinstance(node, dict):
        raise ValueError('node is type {}. path: {}'.format(
            type(node), field_path))
    node_type = node['_type']  # type: str

    if expected_type is not None:
        sum_product = self.types[expected_type]
        if isinstance(sum_product, asdl.Product):
            if node_type != expected_type:
                raise ValueError(
                    'Expected type {}, but instead saw {}. path: {}'.format(
                        expected_type, node_type, field_path))
        elif isinstance(sum_product, asdl.Sum):
            possible_names = [t.name
                              for t in sum_product.types]  # type: List[str]
            if is_seq:
                # Sequence elements may also use the optional
                # `seq_fragment_types` constructors attached to the Sum.
                possible_names += [
                    t.name
                    for t in getattr(sum_product, 'seq_fragment_types', [])]
            if node_type not in possible_names:
                raise ValueError(
                    'Expected one of {}, but instead saw {}. path: {}'.format(
                        ', '.join(possible_names), node_type, field_path))
        else:
            raise ValueError('Unexpected type in ASDL: {}'.format(sum_product))

    if node_type in self.types:
        # Either a product or a sum type; only a product type may be used
        # directly as a node's `_type`.
        sum_product = self.types[node_type]
        if isinstance(sum_product, asdl.Sum):
            raise ValueError('sum type {} not allowed as node type. path: {}'.
                             format(node_type, field_path))
        fields_to_check = sum_product.fields
    elif node_type in self.constructors:
        fields_to_check = self.constructors[node_type].fields
    else:
        raise ValueError('Unknown node_type {}. path: {}'.format(node_type,
                                                                 field_path))

    for field in fields_to_check:
        # field.opt: missing is okay.
        # field.seq: missing is okay; otherwise the value must be a list/tuple.
        if field.name not in node:
            if field.opt or field.seq:
                continue
            raise ValueError('required field {} is missing. path: {}'.format(
                field.name, field_path))

        value = node[field.name]
        if field.seq:
            if not isinstance(value, (list, tuple)):
                raise ValueError('sequential field {} is not sequence. path: {}'.
                                 format(field.name, field_path))
            items = value
        else:
            items = (value, )

        # Check that each item in this field has the expected type. Explicit
        # raises are used instead of `assert`: `python -O` strips asserts,
        # which would skip checking this field and its whole subtree.
        if field.type in self.primitive_type_checkers:
            check = self.primitive_type_checkers[field.type]
            for item in items:
                if not check(item):
                    # AssertionError is kept for compatibility with the
                    # previous `assert check(item)` behaviour.
                    raise AssertionError(
                        'field {} has invalid value {!r} for primitive type {}. '
                        'path: {}'.format(field.name, item, field.type,
                                          field_path))
        else:
            child_path = field_path + (field.name, )
            for item in items:
                self.verify_ast(item, field.type, child_path,
                                is_seq=field.seq)
    return True