in tensorflow_gnn/tools/print_training_data.py [0:0]
def app_main(_):
  """Read some graph tensor training subgraph examples and print them.

  Reads the graph schema from `--graph_schema`, loads serialized examples
  from `--examples` (format from `--file_format`, or guessed from the path)
  and prints them according to `--mode`:

    * 'python': each example is parsed into a GraphTensor, converted with
      `tfgnn.graph_tensor_to_values()` and printed with `pprint.pprint`.
    * 'json': like 'python', but printed as key-sorted, indented JSON.
    * 'textproto': each serialized `tf.train.Example` is printed as a text
      proto; it is neither batched nor parsed against the graph spec.

  If `--num_examples` is set, at most that many examples are printed. In
  'python' and 'json' modes, a truthy `--batch_size` groups the examples into
  batches that are parsed (and printed) together.

  Args:
    _: Unused (presumably the argv list forwarded by the app launcher).

  Raises:
    ValueError: If `--mode` is not one of 'python', 'json' or 'textproto'.
  """
  supported_modes = ('python', 'json', 'textproto')
  if FLAGS.mode not in supported_modes:
    # Previously an unknown mode silently printed nothing; fail fast instead.
    raise ValueError(
        f'Unsupported --mode {FLAGS.mode!r}; '
        f'expected one of {", ".join(supported_modes)}.')

  schema = tfgnn.read_schema(FLAGS.graph_schema)
  spec = tfgnn.create_graph_spec_from_schema_pb(schema)

  # Read the input Example protos.
  file_format = FLAGS.file_format or unigraph.guess_file_format(FLAGS.examples)
  dataset = get_dataset(FLAGS.examples, file_format)

  # Optionally cap the number of examples. This happens before batching so
  # that --num_examples counts individual examples, not batches.
  if FLAGS.num_examples:
    dataset = dataset.take(FLAGS.num_examples)

  if FLAGS.mode == 'textproto':
    # Print the raw protos; no batching or schema-based parsing applies.
    for example_str in dataset:
      example = tf.train.Example()
      example.ParseFromString(example_str.numpy())
      print(example)
    return

  # 'python' or 'json': optionally batch, then parse into GraphTensors.
  if FLAGS.batch_size:
    dataset = dataset.batch(FLAGS.batch_size)
    parser = functools.partial(tfgnn.parse_example, spec)
  else:
    parser = functools.partial(tfgnn.parse_single_example, spec)

  # Pretty-format the values for each of the examples and print them.
  for graph in dataset.map(parser):
    graph_data = tfgnn.graph_tensor_to_values(graph)
    if FLAGS.mode == 'json':
      print(json.dumps(graph_data, sort_keys=True, indent=2))
    else:
      pprint.pprint(graph_data)