def _copy_feature_values()

in tensorflow_gnn/graph/graph_tensor_encode.py [0:0]


def _copy_feature_values(values: gc.Field, fname: str,
                         result: tf.train.Example):
  """Copy the values of an eager tensor to a `Feature` object.

  The tensor is flattened to rank 1 and its values are appended to the feature
  stored under `fname` in `result`: integers go to `int64_list`, floats to
  `float_list` and strings to `bytes_list`. If `values` is a
  `tf.RaggedTensor`, the row lengths of every ragged dimension `i` are
  additionally appended to the `int64_list` of the feature named
  `f'{fname}.d{i}'`.

  Args:
    values: An eager `tf.Tensor` or `tf.RaggedTensor`; `.numpy()` is called on
      it, so it must hold concrete values.
    fname: Key of the feature to fill in `result.features.feature`.
    result: The `tf.train.Example` message that is updated in place.

  Raises:
    ValueError: If the dtype of `values` is not one of int32, int64, float32,
      float64 or string.
  """

  # Flatten out the tensor to a rank-1 tensor.
  if isinstance(values, tf.RaggedTensor):
    flat_values = values.flat_values
  else:
    flat_values = values
  flat_values = tf.reshape(flat_values, [-1])

  # Convert the values to the proper type and set them. Any cast has to happen
  # before the conversion to NumPy, otherwise it has no effect on what is
  # written to the proto.
  feature = result.features.feature[fname]
  dtype = flat_values.dtype
  if dtype == tf.int32:
    feature.int64_list.value.extend(tf.cast(flat_values, tf.int64).numpy())
  elif dtype == tf.int64:
    feature.int64_list.value.extend(flat_values.numpy())
  elif dtype == tf.float32:
    feature.float_list.value.extend(flat_values.numpy())
  elif dtype == tf.float64:
    feature.float_list.value.extend(tf.cast(flat_values, tf.float32).numpy())
  elif dtype == tf.string:
    feature.bytes_list.value.extend(flat_values.numpy())
  else:
    # Report the offending dtype rather than dumping the whole tensor.
    raise ValueError(f'Invalid type for tf.Example: {dtype}')

  # If the tensor has ragged dimensions, serialize those into features to be
  # parsed as partitions.
  if isinstance(values, tf.RaggedTensor):
    # `nested_row_lengths()` yields one entry per row partition, outermost
    # first, and the k-th partition describes dimension k + 1. Pairing them up
    # keeps both in step even when a partition has a statically known
    # (uniform) row length and is therefore skipped; trailing dense dimensions
    # of the flat values have no partition and are dropped by `zip`.
    inner_dims = values.shape.as_list()[1:]
    partitions = zip(inner_dims, values.nested_row_lengths())
    for i, (dim, row_lengths) in enumerate(partitions, start=1):
      if dim is not None:
        continue
      feature = result.features.feature[f'{fname}.d{i}']
      feature.int64_list.value.extend(row_lengths.numpy())