def generic_constraints()

in Sources/TensorFlow/Bindings/generate_wrappers.py [0:0]


  def generic_constraints(self, string_valued):
    """Builds the Swift generic-constraint clause for this attribute.

    Args:
      string_valued: When truthy and `self.allows_string` is also truthy, a
        `type` attribute produces no constraint.

    Returns:
      A string of the form `<swift_name>: <Protocol>`; for function
      attributes, two `TensorGroup` clauses (for the `...In` and `...Out`
      generic parameters) joined by a comma and a newline; or `None` when
      no constraint applies.
    """
    attr_type = self.attr_def.type
    is_type_list = attr_type == 'list(type)'

    # Look up the first input argument whose type (or type list) is driven
    # by this attribute; it is used for obtaining the `_typeList` property.
    matching_arg = None
    if is_type_list:
      matching_arg = next(
        (arg for arg in self.op.input_args
         if self.attr_def.name in [arg.arg_def.type_attr,
                                   arg.arg_def.type_list_attr]),
        None)

    # Function attributes get one generic parameter for their inputs and one
    # for their outputs, each constrained to `TensorGroup`.
    if self.is_func_attr:
      input_type = self.swift_name.capitalize() + 'In'
      output_type = self.swift_name.capitalize() + 'Out'
      return '{}: TensorGroup,\n    {}: TensorGroup'.format(
        input_type, output_type)

    # Only inferred type attributes are constrained.
    if not self.is_inferred_type_attr:
      return None

    if is_type_list:
      # A type list not tied to any input argument is constrained to
      # `TensorGroup`; otherwise to `TensorArrayProtocol`.
      constraint = (
        'TensorGroup' if matching_arg is None else 'TensorArrayProtocol')
    elif attr_type == 'type':
      if string_valued and self.allows_string:
        return None
      # Restrict the allowed types to those that have a Swift counterpart,
      # then pick the first protocol whose type set covers all of them,
      # defaulting to `TensorFlowScalar` when none does.
      candidates = set(self.attr_def.allowed_values.list.type)
      candidates = candidates.intersection(_SWIFTIFIED_TYPES.keys())
      constraint = next(
        (protocol_name for types, protocol_name in _TYPE_PROTOCOLS
         if candidates.issubset(types)),
        'TensorFlowScalar')
    else:
      return None

    return self.swift_name + ': ' + constraint