in Sources/TensorFlow/Bindings/generate_wrappers.py [0:0]
def generic_constraints(self, string_valued):
    """Return the Swift generic constraint text for this attribute, or None.

    The result depends on the kind of attribute:

    * Function attributes (`self.is_func_attr`) yield two constraints,
      `<Name>In: TensorGroup` and `<Name>Out: TensorGroup`, separated by a
      comma, a newline and a space, where `<Name>` is
      `self.swift_name.capitalize()`.
    * Attributes that are not inferred type attributes yield None.
    * `list(type)` attributes yield `TensorArrayProtocol` when some op input
      names this attribute as its `type_attr` or `type_list_attr`, and
      `TensorGroup` otherwise.
    * `type` attributes yield the protocol of the first `_TYPE_PROTOCOLS`
      entry whose type set contains every allowed type that also appears in
      `_SWIFTIFIED_TYPES`, falling back to `TensorFlowScalar`.
    * Any other attribute type yields None.

    Args:
      string_valued: If truthy and `self.allows_string` is also truthy, a
        `type` attribute gets no constraint (None). Presumably this is set
        when emitting a string-specialized wrapper -- confirm against callers.

    Returns:
      `'<swift_name>: <Protocol>'`, the two-constraint function-attribute
      form described above, or None.
    """
    attr_type = self.attr_def.type
    is_type_list = attr_type == 'list(type)'

    # Per the original note, the matching input is what is used to obtain
    # the `_typeList` property. Here only whether one exists matters, so the
    # search stops at the first input driven by this attribute.
    input_arg = None
    if is_type_list:
        attr_name = self.attr_def.name
        input_arg = next(
            (arg for arg in self.op.input_args
             if attr_name in (arg.arg_def.type_attr,
                              arg.arg_def.type_list_attr)),
            None)

    if self.is_func_attr:
        name = self.swift_name.capitalize()
        return '{}: TensorGroup,\n {}: TensorGroup'.format(
            name + 'In', name + 'Out')

    if not self.is_inferred_type_attr:
        return None

    protocol = None
    if is_type_list:
        # A matching input selects TensorArrayProtocol; without one the
        # constraint is TensorGroup.
        protocol = ('TensorGroup' if input_arg is None
                    else 'TensorArrayProtocol')
    elif attr_type == 'type':
        if string_valued and self.allows_string:
            return None
        known_types = set(self.attr_def.allowed_values.list.type).intersection(
            _SWIFTIFIED_TYPES.keys())
        # The first protocol whose type set covers every allowed type wins.
        protocol = next(
            (proto for type_set, proto in _TYPE_PROTOCOLS
             if known_types.issubset(type_set)),
            'TensorFlowScalar')

    if protocol is None:
        return None
    return self.swift_name + ': ' + protocol