def download_and_process_civil_comments_data()

in tensorflow_model_remediation/tools/tutorials_utils/min_diff_keras_utils.py [0:0]


def download_and_process_civil_comments_data():
  """Downloads the processed Civil Comments data and loads it with Pandas.

  Three files are fetched through `tf.keras.utils.get_file` from
  `https://storage.googleapis.com/civil_comments_dataset/`: the processed train
  CSV, the processed validation CSV and the processed validation TFRecord file.
  Only the two CSVs are parsed here; the TFRecord entry is handed back exactly
  as `get_file` returned it.

  Returns:
    A tuple `(data_train, data_validate, validate_tfrecord_file, labels_train,
    labels_validate)` where:
      data_train: DataFrame read from the train CSV, with the `TEXT_FEATURE`
        column cast to `str`.
      data_validate: DataFrame read from the validation CSV, with the
        `TEXT_FEATURE` column cast to `str`.
      validate_tfrecord_file: Result of `get_file` for the validation TFRecord
        file.
      labels_train: Values of the train `LABEL` column reshaped to `(-1, 1)`
        and multiplied by `1.0`.
      labels_validate: Same as `labels_train`, for the validation split.
  """
  base_url = 'https://storage.googleapis.com/civil_comments_dataset/'

  def _fetch(filename):
    # Every remote file lives directly under `base_url` and is stored locally
    # under the same name.
    return tf.keras.utils.get_file(filename, base_url + filename)

  def _label_column(frame):
    # Reshape into a single column; `* 1.0` presumably promotes integer or
    # boolean labels to float -- confirm against the label dtype.
    return frame[LABEL].values.reshape(-1, 1) * 1.0

  # Fetch the CSV splits first, then the validation TFRecords.
  train_csv_path = _fetch('train_df_processed.csv')
  validate_csv_path = _fetch('validate_df_processed.csv')
  validate_tfrecord_file = _fetch('validate_tf_processed.tfrecord')

  # Parse both CSV splits into DataFrames.
  data_train = pd.read_csv(train_csv_path)
  data_validate = pd.read_csv(validate_csv_path)

  # Force the text column to `str` regardless of what `read_csv` inferred.
  for frame in (data_train, data_validate):
    frame[TEXT_FEATURE] = frame[TEXT_FEATURE].astype(str)

  labels_train = _label_column(data_train)
  labels_validate = _label_column(data_validate)

  return (data_train, data_validate, validate_tfrecord_file, labels_train,
          labels_validate)