in Sources/TensorFlow/Bindings/generate_wrappers.py [0:0]
def __init__(self, attr_def, op):
    """Extracts code-generation metadata for one attribute of `op`.

    Args:
      attr_def: The attribute definition being wrapped. Its `name`, `type`
        and, for string attributes, `allowed_values.list.s` fields are read.
        Presumably an `OpDef.AttrDef` protobuf -- TODO confirm.
      op: The op that owns the attribute. Its `op_def.input_arg`,
        `op_def.output_arg`, `input_args` and `enum_store` members are used.

    Raises:
      UnableToGenerateCodeError: If the attribute is neither an inferred
        type attribute nor a `func` attribute, and its type has no entry in
        `_SWIFTIFIED_ATTR_TYPES`.
    """
    self.attr_def = attr_def
    self.op = op
    self.is_type_attr = attr_def.type in ('type', 'list(type)')

    # A type-valued attribute can be inferred automatically when some input
    # or output argument names it as its `type_attr` or `type_list_attr`.
    attr_name = attr_def.name
    input_type_attrs = {
        referenced
        for arg in op.op_def.input_arg
        for referenced in (arg.type_attr, arg.type_list_attr)
    }
    output_type_attrs = {
        referenced
        for arg in op.op_def.output_arg
        for referenced in (arg.type_attr, arg.type_list_attr)
    }
    self.is_inferred_type_attr = (
        attr_name in input_type_attrs or attr_name in output_type_attrs)
    self.is_output_type_attr = attr_name in output_type_attrs
    self.is_func_attr = self.attr_def.type == 'func'

    # First wrapped input argument whose type, type-list or number attribute
    # is this attribute. It is kept for obtaining the `_typeList` property.
    # NOTE(review): `is_inferred_number_attr` is also True when the match is
    # through a type attribute, not only through `number_attr` -- confirm
    # that callers rely on this.
    self.input_arg = next(
        (arg for arg in op.input_args
         if attr_name in (arg.arg_def.type_attr,
                          arg.arg_def.type_list_attr,
                          arg.arg_def.number_attr)),
        None)
    self.is_inferred_number_attr = self.input_arg is not None

    # `_swift_type` and `_use_enum` are only meaningful for attributes whose
    # value is not inferred from argument types.
    self._swift_type = ''
    self._use_enum = False
    if self.is_func_attr:
        # Function attributes map to a Swift closure type built from the
        # capitalized Swift name, e.g. `(NameIn) -> NameOut`.
        base_name = self.swift_name.capitalize()
        self._swift_type = '({}) -> {}'.format(base_name + 'In',
                                               base_name + 'Out')
    elif not self.is_inferred_type_attr:
        attr_type = self.attr_def.type
        if attr_type not in _SWIFTIFIED_ATTR_TYPES:
            raise UnableToGenerateCodeError(
                'Unsupported type for attribute "%s".' % attr_name)
        self._swift_type = _SWIFTIFIED_ATTR_TYPES[attr_type]
        # A string attribute with a non-empty allowed-value list is handed
        # to `op.enum_store.maybe_add`. Its result becomes the Swift type,
        # and the attribute is then flagged as enum-backed.
        if attr_type == 'string':
            allowed_values = tuple(sorted(self.attr_def.allowed_values.list.s))
            if allowed_values:
                self._swift_type = self.op.enum_store.maybe_add(
                    allowed_values, attr_name)
                self._use_enum = True