# def app_main()
#
# in tensorflow_gnn/tools/print_training_data.py [0:0]


def app_main(_):
  """Read some graph tensor training subgraph examples and print them.

  The graph schema is loaded from `FLAGS.graph_schema` and converted to a
  graph tensor spec. Examples are read from `FLAGS.examples` using
  `FLAGS.file_format` (guessed from the path when the flag is unset),
  optionally limited to the first `FLAGS.num_examples` records, and then
  printed according to `FLAGS.mode`:

    * 'python': parsed into graph tensors and pretty-printed as Python values.
    * 'json': parsed into graph tensors and printed as sorted, indented JSON.
    * 'textproto': each record is decoded as a `tf.train.Example` proto and
      printed; no batching is applied in this mode.

  In the 'python' and 'json' modes, a truthy `FLAGS.batch_size` groups the
  records into batches before parsing, so each printed item is one batch.

  Args:
    _: Unused positional argument.
  """
  schema = tfgnn.read_schema(FLAGS.graph_schema)
  spec = tfgnn.create_graph_spec_from_schema_pb(schema)

  # Read the input Example protos.
  file_format = FLAGS.file_format or unigraph.guess_file_format(FLAGS.examples)
  dataset = get_dataset(FLAGS.examples, file_format)

  # Optionally cap the number of examples. This must happen before batching:
  # applied afterwards, take() would count batches instead of examples and
  # print up to batch_size * num_examples records.
  if FLAGS.num_examples:
    dataset = dataset.take(FLAGS.num_examples)

  # Pretty-format the values for each of the examples and print them.
  if FLAGS.mode in {'python', 'json'}:
    # Optionally batch the examples; the parser has to match the input rank
    # (a batch of serialized records vs. a single serialized record).
    if FLAGS.batch_size:
      dataset = dataset.batch(FLAGS.batch_size)
      parser = functools.partial(tfgnn.parse_example, spec)
    else:
      parser = functools.partial(tfgnn.parse_single_example, spec)

    dataset = dataset.map(parser)
    for graph in dataset:
      graph_data = tfgnn.graph_tensor_to_values(graph)
      # Exact comparison: `mode in 'json'` would be a substring test.
      if FLAGS.mode == 'json':
        print(json.dumps(graph_data, sort_keys=True, indent=2))
      else:
        pprint.pprint(graph_data)
  elif FLAGS.mode == 'textproto':
    # Records are kept unbatched so that each one can be decoded on its own.
    for example_str in dataset:
      example = tf.train.Example()
      example.ParseFromString(example_str.numpy())
      print(example)
  # NOTE(review): any other mode value prints nothing; presumably the flag
  # definition restricts the allowed values -- confirm.