def random_graph_tensor()

in tensorflow_gnn/graph/graph_tensor_random.py [0:0]


def random_graph_tensor(
    spec: gt.GraphTensorSpec,
    sample_dict: Optional[SampleDict] = None,
    row_lengths_range: Tuple[int, int] = (2, 8),
    row_splits_dtype: tf.dtypes.DType = tf.int32,
    validate: bool = True) -> gt.GraphTensor:
  """Generate a graph tensor from a schema, with random features.

  Each node set gets a random number of nodes drawn from `row_lengths_range`.
  Each edge set gets a random number of edges scaled to the sizes of its
  source and target node sets, with endpoints drawn uniformly at random from
  those node sets. Every feature declared in `spec` is filled by
  `random_ragged_tensor`, optionally sampling from `sample_dict`.

  Args:
    spec: A GraphTensorSpec instance that describes the graph tensor.
    sample_dict: A dict of (set-type, set-name, field-name) to list-of-values to
      sample from. The intended purpose is to generate random values that are
      more realistic, more representative of what the actual dataset will
      contain. If no values are provided for a feature, random values of the
      feature's dtype are generated instead.
    row_lengths_range: A (minval, maxval) pair bounding the number of nodes
      sampled for each node set. As with `tf.random.uniform`, minval is
      inclusive and maxval is exclusive, so minval must be less than maxval.
      NOTE(review): this range is not forwarded to `random_ragged_tensor`, so
      it does not control the row lengths of ragged features here.
    row_splits_dtype: Data type for row splits, sampled set sizes and
      adjacency indices.
    validate: Forwarded to `random_ragged_tensor` for every feature. If true,
      assertions check that the generated values form valid RaggedTensors.
      Note: these assertions incur a runtime cost, since they must be checked
      for each tensor value.

  Returns:
    An instance of a GraphTensor.
  """
  if sample_dict is None:
    sample_dict = {}

  def _gen_features(set_type: gc.SetType,
                    set_name: Optional[gc.SetName],
                    features_spec: gc.Fields,
                    prefix: Optional[tf.Tensor]):
    """Generate a random feature tensor dict with a possible shape prefix.

    Args:
      set_type: The kind of graph piece (context, nodes or edges).
      set_name: The node/edge set name, or None for the context.
      features_spec: Mapping of feature name to its TensorSpec-like spec.
      prefix: If not None, replaces an unknown (None) leading dimension of
        each feature's shape, so features line up with the sampled set size.

    Returns:
      A dict of feature name to the generated tensor.
    """
    tensors = {}
    for fname, feature_spec in features_spec.items():
      shape = feature_spec.shape.as_list()
      if prefix is not None and shape[0] is None:
        shape[0] = prefix
      key = (set_type, set_name, fname)
      sample_values = sample_dict.get(key, None)
      tensors[fname] = random_ragged_tensor(shape=shape,
                                            dtype=feature_spec.dtype,
                                            sample_values=sample_values,
                                            row_splits_dtype=row_splits_dtype,
                                            validate=validate)
    return tensors

  # Create random context features.
  context = gt.Context.from_fields(
      features=_gen_features(gc.CONTEXT, None,
                             spec.context_spec.features_spec, None))

  # Create random node-set features.
  min_nodes, max_nodes = row_lengths_range
  node_sets = {}
  for set_name, node_set_spec in spec.node_sets_spec.items():
    sizes = tf.random.uniform([1], min_nodes, max_nodes, row_splits_dtype)
    node_sets[set_name] = gt.NodeSet.from_fields(
        sizes=sizes,
        features=_gen_features(gc.NODES, set_name,
                               node_set_spec.features_spec, sizes[0]))

  # Create random edge-set features.
  edge_sets = {}
  for set_name, edge_set_spec in spec.edge_sets_spec.items():
    adj_spec = edge_set_spec.adjacency_spec
    source_size = node_sets[adj_spec.source_name].sizes
    target_size = node_sets[adj_spec.target_name].sizes

    # Generate a reasonable number of edges, scaled to the endpoint sizes.
    sum_sizes = tf.cast(source_size[0] + target_size[0], tf.float32)
    min_edges = tf.cast(sum_sizes / 1.5, row_splits_dtype)
    max_edges = tf.cast(sum_sizes * 2.25, row_splits_dtype)
    # Integer tf.random.uniform requires minval < maxval. When both endpoint
    # sets are empty, min_edges == max_edges == 0, so widen the range by one.
    # For any non-empty pair, max_edges > min_edges already holds, and this
    # is a no-op.
    max_edges = tf.maximum(max_edges, min_edges + 1)
    sizes = tf.random.uniform([1], min_edges, max_edges, dtype=row_splits_dtype)
    # An edge needs a node at each end. If either endpoint set is empty, emit
    # no edges; otherwise sampling indices from [0, 0) below would fail.
    has_endpoints = tf.logical_and(source_size[0] > 0, target_size[0] > 0)
    sizes = tf.where(has_endpoints, sizes, tf.zeros_like(sizes))

    # Generate a random matching. With zero edges, the output is empty and no
    # value range is required.
    source_indices = tf.random.uniform(sizes, 0, source_size[0],
                                       dtype=row_splits_dtype)
    target_indices = tf.random.uniform(sizes, 0, target_size[0],
                                       dtype=row_splits_dtype)
    adjacency = adj.Adjacency.from_indices(
        source=(adj_spec.source_name, source_indices),
        target=(adj_spec.target_name, target_indices))

    # Create the edge set.
    edge_sets[set_name] = gt.EdgeSet.from_fields(
        sizes=sizes,
        features=_gen_features(gc.EDGES, set_name,
                               edge_set_spec.features_spec, sizes[0]),
        adjacency=adjacency)

  return gt.GraphTensor.from_pieces(context=context,
                                    node_sets=node_sets,
                                    edge_sets=edge_sets)