def _generate_elements()

in tfx_addons/sampling/executor.py [0:0]


def _generate_elements(example, label):
  """Fetch the class label from a serialized tf.Example.

  Produces one item of a K-V PCollection whose key is the example's class
  label and whose value is the parsed tf.Example.

  Args:
    example: a tf.Example in serialized format, either as taken directly from
      a TFRecordDataset (any object exposing `.numpy()` that returns the
      serialized bytes) or as the raw serialized bytes themselves.
    label: string containing the name of the categorical variable that we are
      extracting from the example.
  Returns:
    Tuple with two items. The first is the class label:
      - the first value of the feature's int64_list (an int), if non-empty;
      - otherwise the first value of its bytes_list decoded as UTF-8 (a str),
        if non-empty;
      - otherwise None (feature absent, empty, or stored only in float_list).
    The second is the input tf.Example, deserialized from string format and
    left unmodified.
  """
  # Tensors from a TFRecordDataset expose `.numpy()`; plain bytes are used
  # as-is.
  serialized = example.numpy() if hasattr(example, "numpy") else example
  parsed = tf.train.Example.FromString(serialized)

  feature_map = parsed.features.feature
  # Check membership before indexing. Indexing a protobuf message-valued map
  # with a missing key inserts an empty Feature, which would silently add
  # `label` to the example returned below.
  if label not in feature_map:
    return (None, parsed)

  feature = feature_map[label]
  if feature.int64_list.value:
    return (feature.int64_list.value[0], parsed)
  if feature.bytes_list.value:
    return (feature.bytes_list.value[0].decode(), parsed)
  return (None, parsed)