def _convert_comments_data_tfrecord()

in fairness_indicators/tutorial_utils/util.py [0:0]


def _convert_comments_data_tfrecord(input_filename, output_filename=None):
  """Convert the public civil comments data, for tfrecord data.

  Reads serialized `tf.train.Example` records from `input_filename` and, for
  every record whose TEXT_FEATURE is non-empty, writes a new example where:

  * TEXT_FEATURE is copied unchanged.
  * LABEL is binarized: 1 if the first label value is >= _THRESHOLD, else 0.
  * For each category in IDENTITY_COLUMNS, a bytes feature named after the
    category receives the encoded names of the identities in that category
    whose first value is >= _THRESHOLD.

  Records with an empty TEXT_FEATURE are skipped.

  Args:
    input_filename: Path of the TFRecord file to read.
    output_filename: Path of the TFRecord file to write. If not provided, a
      file named 'output' inside a newly created temporary directory is used.

  Returns:
    The path of the TFRecord file that was written.

  Raises:
    IndexError: If a record with text has no LABEL value.
  """
  if not output_filename:
    # TFRecordWriter needs a real path; without this fallback the default
    # `None` would fail before any record is processed.
    import os  # pylint: disable=g-import-not-at-top
    import tempfile  # pylint: disable=g-import-not-at-top
    output_filename = os.path.join(tempfile.mkdtemp(), 'output')

  with tf.io.TFRecordWriter(output_filename) as writer:
    # Iterating the dataset directly and calling `.numpy()` on each element
    # relies on eager execution.
    for serialized in tf.data.TFRecordDataset(filenames=[input_filename]):
      example = tf.train.Example()
      example.ParseFromString(serialized.numpy())
      features = example.features.feature

      text_values = features[TEXT_FEATURE].bytes_list.value
      if not text_values:
        continue

      new_example = tf.train.Example()
      new_features = new_example.features.feature
      new_features[TEXT_FEATURE].bytes_list.value.extend(text_values)

      # Only the first label value is consulted; it is stored as 0/1.
      is_positive = features[LABEL].float_list.value[0] >= _THRESHOLD
      new_features[LABEL].float_list.value.append(1 if is_positive else 0)

      for identity_category, identity_list in IDENTITY_COLUMNS.items():
        grouped_identity = []
        for identity in identity_list:
          identity_values = features[identity].float_list.value
          # An identity without any value is treated as not present.
          if identity_values and identity_values[0] >= _THRESHOLD:
            grouped_identity.append(identity.encode())
        new_features[identity_category].bytes_list.value.extend(
            grouped_identity)
      writer.write(new_example.SerializeToString())

  return output_filename