in fairness_indicators/tutorial_utils/util.py [0:0]
def _convert_comments_data_tfrecord(input_filename, output_filename=None):
  """Convert the public civil comments data, for tfrecord data.

  Reads each serialized `tf.train.Example` from `input_filename` and writes a
  reduced example per record to `output_filename`. Each written example holds:

    * `TEXT_FEATURE`: the bytes values copied unchanged from the input.
    * `LABEL`: a single float, 1 when the input's first label value is at
      least `_THRESHOLD`, otherwise 0.
    * One bytes feature per key of `IDENTITY_COLUMNS`, listing the encoded
      names of the identities in that category whose first float value is at
      least `_THRESHOLD`.

  Records whose `TEXT_FEATURE` has no bytes values are skipped.

  Args:
    input_filename: Path of the TFRecord file to read.
    output_filename: Path of the TFRecord file to write. When None, a new
      temporary file with the same extension as `input_filename` is created
      and used instead.

  Returns:
    The path of the TFRecord file that was written.

  Raises:
    IndexError: If a record has text but no float value under `LABEL`.
  """
  # Imported locally so the None-default path works regardless of which
  # modules the enclosing file imports at the top.
  import os
  import tempfile

  if output_filename is None:
    # The signature advertises None as a valid default, but TFRecordWriter
    # needs a real path; create one and hand it back through the return value.
    suffix = os.path.splitext(input_filename)[1]
    fd, output_filename = tempfile.mkstemp(suffix=suffix)
    # mkstemp returns an open descriptor; only the path is needed here.
    os.close(fd)

  with tf.io.TFRecordWriter(output_filename) as writer:
    for serialized in tf.data.TFRecordDataset(filenames=[input_filename]):
      example = tf.train.Example()
      example.ParseFromString(serialized.numpy())
      features = example.features.feature

      text_values = features[TEXT_FEATURE].bytes_list.value
      if not text_values:
        # Nothing to classify in this record.
        continue

      new_example = tf.train.Example()
      new_features = new_example.features.feature
      new_features[TEXT_FEATURE].bytes_list.value.extend(text_values)

      # Binarize the label against the threshold.
      label_score = features[LABEL].float_list.value[0]
      new_features[LABEL].float_list.value.append(
          1.0 if label_score >= _THRESHOLD else 0.0)

      for identity_category, identity_list in IDENTITY_COLUMNS.items():
        grouped_identity = []
        for identity in identity_list:
          scores = features[identity].float_list.value
          # An identity may have no value on a given record; treat as absent.
          if scores and scores[0] >= _THRESHOLD:
            grouped_identity.append(identity.encode())
        # Extend even when the list is empty so the category key is still
        # touched on every written example, as before.
        new_features[identity_category].bytes_list.value.extend(
            grouped_identity)

      writer.write(new_example.SerializeToString())

  return output_filename