in tensorflow_gnn/graph/graph_tensor_random.py [0:0]
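# Module-level imports are elided in this excerpt; judging from the aliases used
# below, they are presumably along these lines (a sketch, not copied from the
# file; `SampleDict` and `random_ragged_tensor` are defined elsewhere in this
# same module):
#
#   from typing import Optional, Tuple
#   import tensorflow as tf
#   from tensorflow_gnn.graph import adjacency as adj
#   from tensorflow_gnn.graph import graph_constants as gc
#   from tensorflow_gnn.graph import graph_tensor as gt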
def random_graph_tensor(
spec: gt.GraphTensorSpec,
sample_dict: Optional[SampleDict] = None,
row_lengths_range: Tuple[int, int] = (2, 8),
row_splits_dtype: tf.dtypes.DType = tf.int32,
validate: bool = True) -> gt.GraphTensor:
"""Generate a graph tensor from a schema, with random features.
Args:
spec: A GraphTensorSpec instance that describes the graph tensor.
sample_dict: A dict of (set-type, set-name, field-name) to a list of values
to sample from. The intended purpose is to generate random values that are
more realistic, more representative of what the actual dataset will
contain. If no values are provided for a feature, random values of the
right dtype are inserted instead.
row_lengths_range: Minimum and maximum values for the row lengths of each
ragged dimension.
row_splits_dtype: Data type for row splits.
validate: If true, then use assertions to check that the arguments form a
valid RaggedTensor. Note: these assertions incur a runtime cost, since
they must be checked for each tensor value.
Returns:
An instance of a GraphTensor.
"""
if sample_dict is None:
sample_dict = {}
def _gen_features(set_type: gc.SetType,
set_name: Optional[gc.SetName],
features_spec: gc.Fields,
prefix: Optional[tf.Tensor]):
"""Generate a random feature tensor dict with a possible shape prefix."""
tensors = {}
for fname, feature_spec in features_spec.items():
shape = feature_spec.shape.as_list()
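# If the leading dimension is unknown, pin it to the prefix (the number of
# nodes or edges in this set) so the feature lines up with the set's size.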
if prefix is not None and shape[0] is None:
shape[0] = prefix
key = (set_type, set_name, fname)
sample_values = sample_dict.get(key, None)
tensors[fname] = random_ragged_tensor(shape=shape,
dtype=feature_spec.dtype,
sample_values=sample_values,
row_splits_dtype=row_splits_dtype,
validate=validate)
return tensors
# Create random context features.
context = gt.Context.from_fields(
features=_gen_features(gc.CONTEXT, None,
spec.context_spec.features_spec, None))
# Create random node-set features.
min_nodes, max_nodes = row_lengths_range
node_sets = {}
for set_name, node_set_spec in spec.node_sets_spec.items():
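# Draw a single-component size vector: one node count sampled uniformly
# from [min_nodes, max_nodes) for this node set.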
sizes = tf.random.uniform([1], min_nodes, max_nodes, row_splits_dtype)
node_sets[set_name] = gt.NodeSet.from_fields(
sizes=sizes,
features=_gen_features(gc.NODES, set_name,
node_set_spec.features_spec, sizes[0]))
# Create random edge-set features.
edge_sets = {}
for set_name, edge_set_spec in spec.edge_sets_spec.items():
# Generate a reasonable number of edges.
adj_spec = edge_set_spec.adjacency_spec
source_size = node_sets[adj_spec.source_name].sizes
target_size = node_sets[adj_spec.target_name].sizes
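# Heuristic edge count: between roughly 2/3 and 2.25 times the combined
# number of source and target nodes, so graphs are neither empty nor
# overly dense.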
sum_sizes = tf.cast(source_size[0] + target_size[0], tf.float32)
min_edges = tf.cast(sum_sizes / 1.5, row_splits_dtype)
max_edges = tf.cast(sum_sizes * 2.25, row_splits_dtype)
sizes = tf.random.uniform([1], min_edges, max_edges, dtype=row_splits_dtype)
# Sample random source and target indices for each edge; `sizes` (a
# length-1 vector) doubles as the shape of the index tensors.
source_indices = tf.random.uniform(sizes, 0, source_size[0],
dtype=row_splits_dtype)
target_indices = tf.random.uniform(sizes, 0, target_size[0],
dtype=row_splits_dtype)
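# Wire up the adjacency: each endpoint is a (node-set name, index tensor) pair.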
adjacency = adj.Adjacency.from_indices(
source=(adj_spec.source_name, source_indices),
target=(adj_spec.target_name, target_indices))
# Create the edge set.
edge_sets[set_name] = gt.EdgeSet.from_fields(
sizes=sizes,
features=_gen_features(gc.EDGES, set_name,
edge_set_spec.features_spec, sizes[0]),
adjacency=adjacency)
return gt.GraphTensor.from_pieces(context=context,
node_sets=node_sets,
edge_sets=edge_sets)
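# Usage sketch (illustrative only; the set and feature names below are
# hypothetical, not taken from this module). Any GraphTensorSpec works, for
# example the `.spec` of an existing GraphTensor or a spec built from a
# GraphSchema:
#
#   import tensorflow_gnn as tfgnn
#
#   spec = existing_graph.spec  # `existing_graph`: some GraphTensor.
#   graph = tfgnn.random_graph_tensor(
#       spec,
#       sample_dict={(tfgnn.NODES, "papers", "year"): [2018, 2019, 2020]},
#       row_lengths_range=(4, 16))
#   # `graph` matches the structure of `spec`, with random sizes, adjacency,
#   # and feature values.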