def encode_subgraph_to_example()

in tensorflow_gnn/sampler/subgraph.py

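The listing below shows only the function; the names it uses (collections, Dict, List, gnn, Example, Feature, Subgraph, subgraph_pb2, and the private helpers _prepare_feature_dict, _copy_features, get_feature_length, strip_empty_features) are defined at module level elsewhere in subgraph.py. As a rough sketch, the relevant aliases look something like the following; the exact import paths in the file may differ.

import collections
from typing import Dict, List

import tensorflow as tf
import tensorflow_gnn as gnn                      # assumed alias for the tensorflow_gnn package
from tensorflow_gnn.sampler import subgraph_pb2   # generated proto module defining Subgraph

Example = tf.train.Example    # assumed alias: output proto type
Feature = tf.train.Feature    # assumed alias: per-feature proto type
Subgraph = subgraph_pb2.Subgraph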

def encode_subgraph_to_example(schema: gnn.GraphSchema,
                               subgraph: Subgraph) -> Example:
  """Convert a Subgraph to an encoded graph tensor."""

  # TODO(blais): Factor out the repeated bits from the static schema for reuse.

  # Copy context features. Nothing to aggregate here, simple copy.
  example = Example()
  for key in schema.context.features:
    feature = subgraph.features.feature.get(key, None)
    if feature is not None:
      newkey = "context/{}".format(key)
      example.features.feature[newkey].CopyFrom(feature)

  # Prepare to store node and edge features.
  node_features_dicts: Dict[str, Dict[str, Feature]] = {}
  edge_features_dicts: Dict[str, Dict[str, Feature]] = {}
  for nset_name, nset_obj in schema.node_sets.items():
    node_features_dicts[nset_name] = _prepare_feature_dict(
        nset_name, nset_obj, "nodes", example)
  for eset_name, eset_obj in schema.edge_sets.items():
    edge_features_dicts[eset_name] = _prepare_feature_dict(
        eset_name, eset_obj, "edges", example)

  # Prepare to store edge indices.
  by_node_set_name: Dict[
      gnn.NodeSetName, List[subgraph_pb2.Node]] = collections.defaultdict(list)
  for node in subgraph.nodes:
    by_node_set_name[node.node_set_name].append(node)
  index_map: Dict[bytes, int] = {}
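  # Note: indices are assigned per node set, so a node's index is its position
  # within its own node set's list; the edge source/target features below
  # refer to these per-set indices.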
  for node_lists in by_node_set_name.values():
    index_map.update({node.id: i for i, node in enumerate(node_lists)})

  # Iterate over the nodes and edges.
  node_counter = collections.defaultdict(int)
  edge_counter = collections.defaultdict(int)
  for node in subgraph.nodes:
    node_counter[node.node_set_name] += 1
    node_features_dict = node_features_dicts[node.node_set_name]
    _copy_features(node.features, node_features_dict)

    # Iterate over outgoing edges.
    source_idx = index_map[node.id]
    for edge in node.outgoing_edges:
      # Append indices.
      target_idx = index_map.get(edge.neighbor_id, None)
      if target_idx is None:
        # Fail on edge references to a node that isn't in the subgraph.
        raise ValueError("Edge to node outside subgraph: '{}': {}".format(
            edge.neighbor_id, subgraph))

      source_name = "edges/{}.{}".format(edge.edge_set_name, gnn.SOURCE_NAME)
      target_name = "edges/{}.{}".format(edge.edge_set_name, gnn.TARGET_NAME)
      example.features.feature[source_name].int64_list.value.append(source_idx)
      example.features.feature[target_name].int64_list.value.append(target_idx)
      edge_counter[edge.edge_set_name] += 1

      # Store the edge features.
      edge_features_dict = edge_features_dicts[edge.edge_set_name]
      _copy_features(edge.features, edge_features_dict)

  # Produce size features.
  for node_set_name, num_nodes in node_counter.items():
    node_size_name = "nodes/{}.{}".format(node_set_name, gnn.SIZE_NAME)
    example.features.feature[node_size_name].int64_list.value.append(num_nodes)
  for edge_set_name, num_edges in edge_counter.items():
    edge_size_name = "edges/{}.{}".format(edge_set_name, gnn.SIZE_NAME)
    example.features.feature[edge_size_name].int64_list.value.append(num_edges)

  # Check the feature sizes (in aggregate).
  # TODO(blais): Support ragged features in this sampler eventually.
  for num_counter, features_dicts in [(node_counter, node_features_dicts),
                                      (edge_counter, edge_features_dicts)]:
    for set_name, features_dict in features_dicts.items():
      num = num_counter[set_name]
      for feature_name, out_feature in features_dict.items():
        out_length = get_feature_length(out_feature)
        if num > 0 and out_length % num != 0:
          raise ValueError(
              "Invalid number ({}) of features '{}' for set '{}' in subgraph '{}' for schema '{}'"
              .format(out_length, feature_name, set_name, subgraph, schema))

  strip_empty_features(example)
  return example
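
A minimal usage sketch (not part of the library): it builds the smallest Subgraph the function can encode, setting only the proto fields that encode_subgraph_to_example() reads. The node set "author", the edge set "writes", the schema text (including the source/target fields), and the use of gnn.parse_schema() to build the GraphSchema from a text proto are assumptions for illustration.

# Hypothetical example; schema contents and set names are placeholders.
schema = gnn.parse_schema("""
  node_sets { key: "author" value {} }
  edge_sets {
    key: "writes"
    value { source: "author" target: "author" }
  }
""")

sg = subgraph_pb2.Subgraph()
node_a = sg.nodes.add(id=b"a", node_set_name="author")
sg.nodes.add(id=b"b", node_set_name="author")
node_a.outgoing_edges.add(neighbor_id=b"b", edge_set_name="writes")

example = encode_subgraph_to_example(schema, sg)
# `example` now holds flat features keyed by "nodes/author." + gnn.SIZE_NAME,
# "edges/writes." + gnn.SOURCE_NAME and "edges/writes." + gnn.TARGET_NAME,
# and can be serialized with example.SerializeToString().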