in Sources/TensorFlow/Bindings/generate_wrappers.py [0:0]
def _load_api_def_overrides(api_def_map, op_names, api_def_path):
  """Merges per-op `api_def_<OpName>.pbtxt` files into `api_def_map`.

  Args:
    api_def_map: the `c_api_util.ApiDefMap` updated in place via
      `put_api_def`.
    op_names: names of the ops whose files should be looked up.
    api_def_path: directory containing the `.pbtxt` files.

  Ops with no file are skipped. A file that `put_api_def` rejects is reported
  on stdout and skipped (best-effort), so one bad file does not abort
  generation.
  """
  for op_name in op_names:
    path = os.path.join(api_def_path, 'api_def_%s.pbtxt' % op_name)
    if not tf.gfile.Exists(path):
      continue
    with tf.gfile.Open(path, 'r') as fobj:
      data = fobj.read()
    try:
      api_def_map.put_api_def(data)
    except Exception as e:  # Deliberately best-effort; see docstring.
      print('Cannot load api def for %s: %s' % (op_name, str(e)))


def _generate_op_code(api_def_map, enum_store, op_name):
  """Generates the Swift wrapper code for a single op.

  Args:
    api_def_map: the `c_api_util.ApiDefMap` providing op and API defs.
    enum_store: the `EnumStore` shared by all generated ops.
    op_name: name of the op to wrap.

  Returns:
    A `(raw_codes, forwarding_codes)` pair of lists of Swift source strings.
    The first list holds the `_RawTFEager` function(s); the second holds the
    `_Raw` dispatch function(s). The `String`-valued variant is included in
    both lists only when its raw function differs from the default one.

  Raises:
    UnableToGenerateCodeError: if the op has ref-valued inputs or outputs, or
      if code generation for it otherwise fails.
  """
  op_def = api_def_map.get_op_def(op_name)
  if any(a.is_ref for a in op_def.input_arg):
    raise UnableToGenerateCodeError('has ref-valued input')
  if any(a.is_ref for a in op_def.output_arg):
    raise UnableToGenerateCodeError('has ref-valued output')
  api_def = api_def_map.get_api_def(bytes(op_name, 'utf8'))
  # It would be nicer to handle `StringTensor` in a more
  # general way by having `String` conform to `TensorFlowScalar`.
  default_op = Op(op_def, api_def, enum_store, string_valued=False)
  string_valued_op = Op(op_def, api_def, enum_store, string_valued=True)
  # The calls into `Op` happen in the same order as before, because they may
  # touch the shared `enum_store`. Every string is produced before anything
  # is returned, so a failure here cannot leave a half-emitted op behind.
  default_code = default_op.swift_function()
  string_valued_code = string_valued_op.swift_function()
  default_dispatch = default_op.swift_dispatch_function(
      x10_supported=op_name in X10_OPS)
  string_valued_dispatch = string_valued_op.swift_dispatch_function()
  # Whether the string-valued variant is emitted depends only on the raw
  # function text, not on the dispatch text.
  if string_valued_code != default_code:
    return ([default_code, string_valued_code],
            [default_dispatch, string_valued_dispatch])
  return [default_code], [default_dispatch]


def main(argv):
  """Generates the Swift raw-op bindings and writes them to disk.

  Wrappers are generated for every op name reported by
  `c_api_util.ApiDefMap`, except those starting with '_'. When
  `FLAGS.api_def_path` is set, API-def files from that directory are merged
  in first.

  The `_RawTFEager` enum is written to `FLAGS.output_path`. When
  `FLAGS.dispatching_output_path` is set, the `_Raw` dispatcher enum is also
  written there.

  Args:
    argv: unused.

  Raises:
    ValueError: if `FLAGS.output_path` is not set.
  """
  del argv  # Unused.
  if FLAGS.output_path is None:
    raise ValueError('No output_path has been set')
  api_def_map = c_api_util.ApiDefMap()
  op_codes = []
  op_codes_forwarding = []
  enum_store = EnumStore()
  op_names = api_def_map.op_names()
  if FLAGS.api_def_path is not None:
    _load_api_def_overrides(api_def_map, op_names, FLAGS.api_def_path)

  num_generated = 0
  for op_name in sorted(op_names):
    # Names starting with '_' are skipped (presumably internal ops).
    if op_name.startswith('_'):
      continue
    try:
      raw_codes, forwarding_codes = _generate_op_code(
          api_def_map, enum_store, op_name)
    except UnableToGenerateCodeError as e:
      print('Cannot generate code for %s: %s' % (op_name, e.details))
      continue
    # Append both halves only after the whole op succeeded. This keeps
    # `_RawTFEager` and `_Raw` in agreement about which ops exist.
    op_codes.extend(raw_codes)
    op_codes_forwarding.extend(forwarding_codes)
    num_generated += 1
  print('Generated code for %d/%d ops.' % (num_generated, len(op_names)))

  version_codes = [
      ' static let generatedTensorFlowVersion = "%s"' % tf.__version__,
      ' static let generatedTensorFlowGitVersion = "%s"' % tf.__git_version__]

  swift_code = (
      _WARNING +
      _HEADER +
      'import CTensorFlow\n\n' +
      '@inlinable @inline(__always)\n' +
      'func makeOp(_ name: String, _ nOutputs: Int) -> TFTensorOperation {\n' +
      ' _ExecutionContext.makeOp(name, nOutputs)\n' +
      '}\n' +
      '\n\npublic enum _RawTFEager {\n\n' +
      '\n'.join(version_codes) +
      '\n\n' +
      '\n\n'.join(enum_store.enum_codes()) +
      '\n\n' +
      '\n'.join(op_codes) +
      '\n\n}\n')
  with tf.gfile.Open(FLAGS.output_path, 'w') as f:
    f.write(swift_code)

  dispatching_code = (
      _WARNING +
      _HEADER +
      _DISPATCHER_TEMPLATE.format(raw_dispatching_enum=(
          'public enum _Raw {\n\n' +
          '\n'.join(version_codes) +
          '\n\n' +
          '\n\n'.join(enum_store.enum_codes_forwarding()) +
          '\n\n' +
          '\n'.join(op_codes_forwarding) + '\n\n}')))
  if FLAGS.dispatching_output_path:
    with tf.gfile.Open(FLAGS.dispatching_output_path, 'w') as f:
      f.write(dispatching_code)