in tensorflow_gnn/graph/graph_tensor_encode.py
import tensorflow as tf

from tensorflow_gnn.graph import graph_constants as gc


def _copy_feature_values(values: gc.Field, fname: str,
                         result: tf.train.Example):
  """Copies the values of an eager tensor to a `Feature` object in `result`."""
  # Flatten the (possibly ragged) tensor out to a rank-1 array of values.
  flat_values = values
  if isinstance(flat_values, tf.RaggedTensor):
    flat_values = values.flat_values
  flat_values = tf.reshape(flat_values, [-1])
  array = flat_values.numpy()
  # Convert the values to a type supported by `tf.train.Example` and append
  # them to the feature named `fname`.
  feature = result.features.feature[fname]
  if flat_values.dtype is tf.int32:
    # `tf.train.Example` stores integers as int64; widen before writing.
    array = tf.cast(flat_values, tf.int64).numpy()
    feature.int64_list.value.extend(array)
  elif flat_values.dtype is tf.int64:
    feature.int64_list.value.extend(array)
  elif flat_values.dtype is tf.float32:
    feature.float_list.value.extend(array)
  elif flat_values.dtype is tf.float64:
    # `tf.train.Example` stores floats as 32-bit; downcast before writing.
    array = tf.cast(flat_values, tf.float32).numpy()
    feature.float_list.value.extend(array)
  elif flat_values.dtype is tf.string:
    feature.bytes_list.value.extend(array)
  else:
    raise ValueError(
        f'Unsupported dtype {flat_values.dtype!r} for tf.Example feature '
        f'{fname!r}')
  # If the tensor has ragged dimensions, serialize their row lengths into
  # companion int64 features (one per ragged dimension) to be parsed back as
  # partitions.
  if isinstance(values, tf.RaggedTensor):
    iter_row_lengths = iter(values.nested_row_lengths())
    for i, dim in enumerate(values.shape.as_list()[1:], start=1):
      if dim is not None:
        continue
      row_lengths = next(iter_row_lengths)
      feature = result.features.feature[f'{fname}.d{i}']
      feature.int64_list.value.extend(row_lengths.numpy())
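

# ----------------------------------------------------------------------------
# Usage sketch (not part of the library): a minimal illustration of how a
# helper like `_copy_feature_values` behaves end to end, assuming eager mode.
# The feature names ('nodes/ids', 'nodes/weights') and the parsing specs are
# illustrative assumptions, not the library's own encoding API; only the
# '<fname>.d<i>' row-length convention comes from the code above.
def _usage_sketch():
  example = tf.train.Example()
  # A dense int32 tensor is flattened and widened into an int64_list feature.
  _copy_feature_values(
      tf.constant([[1, 2], [3, 4]], dtype=tf.int32), 'nodes/ids', example)
  # A ragged float32 tensor stores its flat values in a float_list feature and
  # its row lengths [3, 1] in the companion feature 'nodes/weights.d1'. Row
  # lengths (rather than row splits) keep one entry per row of the feature.
  _copy_feature_values(
      tf.ragged.constant([[1.0, 2.0, 3.0], [4.0]], dtype=tf.float32),
      'nodes/weights', example)
  # The companion feature lets a parser rebuild the ragged partitions.
  parsed = tf.io.parse_single_example(
      example.SerializeToString(), {
          'nodes/ids': tf.io.FixedLenFeature([4], tf.int64),
          'nodes/weights': tf.io.RaggedFeature(
              tf.float32,
              value_key='nodes/weights',
              partitions=[tf.io.RaggedFeature.RowLengths('nodes/weights.d1')]),
      })
  return parsed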