def main()

in Sources/TensorFlow/Bindings/generate_wrappers.py [0:0]


def _load_api_def_overrides(api_def_map, op_names, api_def_dir):
  """Merges per-op `api_def_<op>.pbtxt` files from `api_def_dir` into the map.

  Ops without a matching file are skipped. A file whose contents are rejected
  by `put_api_def` is reported and skipped (best effort), so one bad override
  does not abort generation. Errors while reading a file still propagate.

  Args:
    api_def_map: The `c_api_util.ApiDefMap` to update in place.
    op_names: Names of the ops whose override files should be looked up.
    api_def_dir: Directory containing the `api_def_*.pbtxt` files.
  """
  for op_name in op_names:
    path = os.path.join(api_def_dir, 'api_def_%s.pbtxt' % op_name)
    if not tf.gfile.Exists(path):
      continue
    with tf.gfile.Open(path, 'r') as fobj:
      data = fobj.read()
    try:
      api_def_map.put_api_def(data)
    except Exception as e:  # Deliberately best effort: report and continue.
      print('Cannot load api def for %s: %s' % (op_name, str(e)))


def _generate_op_wrappers(api_def_map, op_name, enum_store):
  """Generates the raw and dispatching Swift wrappers for a single op.

  Both kinds of wrapper are generated before anything is returned, so the
  caller can record an op in both output files or in neither.

  Args:
    api_def_map: The `c_api_util.ApiDefMap` holding op and API definitions.
    op_name: Name of the op to generate wrappers for.
    enum_store: The `EnumStore` shared across all generated ops.

  Returns:
    A `(raw_codes, dispatch_codes)` pair of lists of Swift source strings.
    Each list holds the default wrapper, followed by the string-valued
    variant only when its raw code differs from the default one.

  Raises:
    UnableToGenerateCodeError: If the op has ref-valued inputs or outputs, or
      if `Op` cannot produce one of the wrappers.
  """
  op_def = api_def_map.get_op_def(op_name)
  if any(a.is_ref for a in op_def.input_arg):
    raise UnableToGenerateCodeError('has ref-valued input')
  if any(a.is_ref for a in op_def.output_arg):
    raise UnableToGenerateCodeError('has ref-valued output')
  api_def = api_def_map.get_api_def(bytes(op_name, 'utf8'))

  # It would be nicer to handle `StringTensor` in a more
  # general way by having `String` conform to `TensorFlowScalar`.
  default_op = Op(op_def, api_def, enum_store, string_valued=False)
  string_valued_op = Op(op_def, api_def, enum_store, string_valued=True)

  default_raw = default_op.swift_function()
  string_valued_raw = string_valued_op.swift_function()
  default_dispatch = default_op.swift_dispatch_function(
      x10_supported=op_name in X10_OPS)
  string_valued_dispatch = string_valued_op.swift_dispatch_function()

  # The string-valued variant is only emitted when its raw code differs from
  # the default; its dispatch wrapper follows the same decision.
  if string_valued_raw != default_raw:
    return ([default_raw, string_valued_raw],
            [default_dispatch, string_valued_dispatch])
  return [default_raw], [default_dispatch]


def _render_raw_module(version_codes, enum_codes, op_codes):
  """Returns the Swift source declaring `makeOp` and `enum _RawTFEager`.

  Args:
    version_codes: Swift lines declaring the generated TF version constants.
    enum_codes: Swift enum declarations, joined with blank lines.
    op_codes: Raw op wrapper functions, joined with newlines.
  """
  return (
      _WARNING +
      _HEADER +
      'import CTensorFlow\n\n' +
      '@inlinable @inline(__always)\n' +
      'func makeOp(_ name: String, _ nOutputs: Int) -> TFTensorOperation {\n' +
      '  _ExecutionContext.makeOp(name, nOutputs)\n' +
      '}\n' +
      '\n\npublic enum _RawTFEager {\n\n' +
      '\n'.join(version_codes) +
      '\n\n' +
      '\n\n'.join(enum_codes) +
      '\n\n' +
      '\n'.join(op_codes) +
      '\n\n}\n')


def _render_dispatching_module(version_codes, enum_codes_forwarding,
                               op_codes_forwarding):
  """Returns the Swift source for the `enum _Raw` dispatcher.

  The enum body is substituted into `_DISPATCHER_TEMPLATE` via its
  `raw_dispatching_enum` placeholder.

  Args:
    version_codes: Swift lines declaring the generated TF version constants.
    enum_codes_forwarding: Forwarding enum declarations, joined with blank
      lines.
    op_codes_forwarding: Dispatching op wrapper functions, joined with
      newlines.
  """
  raw_dispatching_enum = (
      'public enum _Raw {\n\n' +
      '\n'.join(version_codes) +
      '\n\n' +
      '\n\n'.join(enum_codes_forwarding) +
      '\n\n' +
      '\n'.join(op_codes_forwarding) + '\n\n}')
  return (
      _WARNING +
      _HEADER +
      _DISPATCHER_TEMPLATE.format(raw_dispatching_enum=raw_dispatching_enum))


def main(argv):
  """Generates Swift wrappers for the TensorFlow ops known to `ApiDefMap`.

  Always writes the `_RawTFEager` bindings to `FLAGS.output_path`. Also
  writes the `_Raw` dispatcher to `FLAGS.dispatching_output_path` when that
  flag is set. If `FLAGS.api_def_path` is set, per-op API def overrides are
  loaded from it before generation. Ops that cannot be generated are reported
  and left out of both files.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If `FLAGS.output_path` is not set.
  """
  del argv  # Unused.
  if FLAGS.output_path is None:
    raise ValueError('No output_path has been set')

  api_def_map = c_api_util.ApiDefMap()
  enum_store = EnumStore()
  op_names = api_def_map.op_names()
  if FLAGS.api_def_path is not None:
    _load_api_def_overrides(api_def_map, op_names, FLAGS.api_def_path)

  op_codes = []
  op_codes_forwarding = []
  num_generated = 0
  for op_name in sorted(op_names):
    # Underscore-prefixed ops are skipped (presumably TF-internal ops).
    if op_name.startswith('_'):
      continue
    try:
      raw_codes, dispatch_codes = _generate_op_wrappers(
          api_def_map, op_name, enum_store)
    except UnableToGenerateCodeError as e:
      print('Cannot generate code for %s: %s' % (op_name, e.details))
      continue
    # Commit only after both wrappers were produced, so a failure during
    # dispatch generation cannot leave a raw-only entry in `_RawTFEager`.
    op_codes.extend(raw_codes)
    op_codes_forwarding.extend(dispatch_codes)
    num_generated += 1
  print('Generated code for %d/%d ops.' % (num_generated, len(op_names)))

  version_codes = [
      '  static let generatedTensorFlowVersion = "%s"' % tf.__version__,
      '  static let generatedTensorFlowGitVersion = "%s"' % tf.__git_version__]

  raw_code = _render_raw_module(
      version_codes, enum_store.enum_codes(), op_codes)
  with tf.gfile.Open(FLAGS.output_path, 'w') as f:
    f.write(raw_code)

  # Only assemble the dispatcher source when there is somewhere to write it.
  if FLAGS.dispatching_output_path:
    dispatching_code = _render_dispatching_module(
        version_codes, enum_store.enum_codes_forwarding(), op_codes_forwarding)
    with tf.gfile.Open(FLAGS.dispatching_output_path, 'w') as f:
      f.write(dispatching_code)