def main()

in tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs/example.py [0:0]


def main(unused_argv):
  """Trains the example models and reports membership inference attack results.

  Every model in the local `models` mapping is trained in `num_rounds` rounds
  of `num_epochs` epochs each. After each round the membership inference
  attacks are run on the current predictions and the results are appended to a
  collection that feeds the privacy report plots.

  The steps after the training loops (save/load round trip, per-attack
  metrics, summaries, pandas frame and ROC plot) operate only on the most
  recent attack results, i.e. those of the last model after its final round.

  Args:
    unused_argv: Not used by this function.
  """
  epoch_results = data_structures.AttackResultsCollection([])

  num_epochs = 2
  # Number of incremental training rounds per model; attacks run after each.
  num_rounds = 5
  models = {
      "two layer model": two_layer_model,
      "three layer model": three_layer_model,
  }

  # The one-hot encoded targets are identical for every round and model, so
  # encode them once instead of on every call to fit().
  training_targets = tf.keras.utils.to_categorical(
      training_labels, num_clusters)
  test_targets = tf.keras.utils.to_categorical(test_labels, num_clusters)

  for model_name, model in models.items():
    # Incrementally train the model and store privacy metrics every num_epochs.
    for round_index in range(1, num_rounds + 1):
      model.fit(
          training_features,
          training_targets,
          validation_data=(test_features, test_targets),
          batch_size=64,
          epochs=num_epochs,
          shuffle=True)

      training_pred = model.predict(training_features)
      test_pred = model.predict(test_features)

      # Add metadata to generate a privacy report.
      privacy_report_metadata = data_structures.PrivacyReportMetadata(
          accuracy_train=metrics.accuracy_score(
              training_labels, np.argmax(training_pred, axis=1)),
          accuracy_test=metrics.accuracy_score(test_labels,
                                               np.argmax(test_pred, axis=1)),
          # Total number of epochs this model has been trained for so far.
          epoch_num=num_epochs * round_index,
          model_variant_label=model_name)

      attack_results = mia.run_attacks(
          data_structures.AttackInputData(
              labels_train=training_labels,
              labels_test=test_labels,
              probs_train=training_pred,
              probs_test=test_pred,
              loss_train=crossentropy(training_labels, training_pred),
              loss_test=crossentropy(test_labels, test_pred)),
          data_structures.SlicingSpec(entire_dataset=True, by_class=True),
          attack_types=(data_structures.AttackType.THRESHOLD_ATTACK,
                        data_structures.AttackType.LOGISTIC_REGRESSION),
          privacy_report_metadata=privacy_report_metadata)
      epoch_results.append(attack_results)

  # Generate privacy reports
  report_metrics = [
      data_structures.PrivacyMetric.ATTACKER_ADVANTAGE,
      data_structures.PrivacyMetric.AUC
  ]
  epoch_figure = privacy_report.plot_by_epochs(epoch_results,
                                               list(report_metrics))
  epoch_figure.show()
  privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy(
      epoch_results, list(report_metrics))
  privacy_utility_figure.show()

  # From here on `attack_results` refers to the results of the last round of
  # the last model trained above.

  # Example of saving the results to the file and loading them back.
  with tempfile.TemporaryDirectory() as tmpdirname:
    filepath = os.path.join(tmpdirname, "results.pickle")
    attack_results.save(filepath)
    loaded_results = data_structures.AttackResults.load(filepath)
    print(loaded_results.summary(by_slices=False))

  # Print attack metrics
  for attack_result in attack_results.single_attack_results:
    roc_curve = attack_result.roc_curve
    print("Slice: %s" % attack_result.slice_spec)
    print("Attack type: %s" % attack_result.attack_type)
    print("AUC: %.2f" % roc_curve.get_auc())

    print("Attacker advantage: %.2f\n" % roc_curve.get_attacker_advantage())

  max_auc_attacker = attack_results.get_result_with_max_auc()
  print("Attack type with max AUC: %s, AUC of %.2f" %
        (max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc()))

  max_advantage_attacker = (
      attack_results.get_result_with_max_attacker_advantage())
  print("Attack type with max advantage: %s, Attacker advantage of %.2f" %
        (max_advantage_attacker.attack_type,
         max_advantage_attacker.roc_curve.get_attacker_advantage()))

  # Print summary
  print("Summary without slices: \n")
  print(attack_results.summary(by_slices=False))

  print("Summary by slices: \n")
  print(attack_results.summary(by_slices=True))

  # Print pandas data frame
  print("Pandas frame: \n")
  pd.set_option("display.max_rows", None, "display.max_columns", None)
  print(attack_results.calculate_pd_dataframe())

  # Example of ROC curve plotting.
  figure = plotting.plot_roc_curve(
      attack_results.single_attack_results[0].roc_curve)
  figure.show()
  plt.show()