in tfx_addons/sampling/executor.py [0:0]
def _generate_elements(example, label):
    """Pair a serialized tf.Example with the class label stored inside it.

    The result is a (key, value) tuple, intended to be used as one element
    of a keyed PCollection: the key is the class label and the value is the
    parsed tf.Example.

    Args:
      example: object exposing `.numpy()` that returns a serialized
        tf.Example (for instance, an element read from a TFRecordDataset).
      label: name of the feature that holds the class label.

    Returns:
      Tuple `(class_label, parsed)`. `parsed` is the deserialized
      tf.train.Example. `class_label` is:
        * the first value of the feature's int64_list, if it has any;
        * otherwise the first value of its bytes_list, decoded to `str`
          with the default codec of `bytes.decode()`;
        * otherwise `None` (the feature is missing, empty, or stored as
          another kind such as float_list).
    """
    parsed = tf.train.Example.FromString(example.numpy())
    feature_map = parsed.features.feature

    # Indexing a protobuf message map with an absent key inserts an empty
    # entry for it, which would silently alter the example handed back to
    # the caller. Test membership first so `parsed` is returned untouched.
    if label not in feature_map:
        return (None, parsed)

    feature = feature_map[label]

    # Integer labels take precedence over bytes labels.
    int_values = feature.int64_list.value
    if int_values:
        return (int_values[0], parsed)

    bytes_values = feature.bytes_list.value
    if bytes_values:
        return (bytes_values[0].decode(), parsed)

    return (None, parsed)