def _create_tpu_estimator_spec()

in tensorflow_privacy/privacy/estimators/binary_class_head.py [0:0]


  def _create_tpu_estimator_spec(self,
                                 features,
                                 mode,
                                 logits,
                                 labels=None,
                                 optimizer=None,
                                 trainable_variables=None,
                                 train_op_fn=None,
                                 update_ops=None,
                                 regularization_losses=None):
    """Builds a `_TPUEstimatorSpec` for PREDICT, EVAL or TRAIN mode.

    See superclass for the full description of the arguments.

    - PREDICT: returns the predictions with four export signatures (default,
      classify, regress and predict).
    - Every other mode: first computes the regularized training loss with
      `self.loss` and reports its `tf.reduce_mean` as the spec's scalar loss.
    - EVAL: attaches an eval-metrics tuple built from `self.update_metrics`.
    - Any remaining mode: handled as TRAIN. It builds a train op from the
      loss tensor exactly as returned by `self.loss`, and calls
      `base_head.create_estimator_spec_summary` outside the head's name scope.

    Args:
      features: Passed to `self.loss` and to the eval metric update function.
      mode: A `ModeKeys` value selecting which spec to build.
      logits: Passed to `self.predictions`, `self.loss` and the eval metric
        update function.
      labels: Passed to `self.loss` and the eval metric update function.
      optimizer: Forwarded to `base_head.create_estimator_spec_train_op`.
      trainable_variables: Forwarded to
        `base_head.create_estimator_spec_train_op`.
      train_op_fn: Forwarded to `base_head.create_estimator_spec_train_op`.
      update_ops: Forwarded to `base_head.create_estimator_spec_train_op`.
      regularization_losses: Forwarded to `self.loss`, `self.metrics`, the
        eval metric update function and the training summary.

    Returns:
      A `model_fn._TPUEstimatorSpec` for the requested mode.
    """
    with tf.compat.v1.name_scope(self._name, default_name='head'):
      keys = prediction_keys.PredictionKeys
      preds = self.predictions(logits)

      if mode == ModeKeys.PREDICT:
        class_scores = preds[keys.PROBABILITIES]
        logistic_value = preds[keys.LOGISTIC]
        classification = base_head.classification_output(
            scores=class_scores,
            n_classes=2,
            label_vocabulary=self._label_vocabulary)
        # The same classification signature is exposed under both the
        # default and the classify serving keys.
        serving_outputs = {
            base_head.DEFAULT_SERVING_KEY: classification,
            base_head.CLASSIFY_SERVING_KEY: classification,
        }
        serving_outputs[base_head.REGRESS_SERVING_KEY] = (
            export_output.RegressionOutput(value=logistic_value))
        serving_outputs[base_head.PREDICT_SERVING_KEY] = (
            export_output.PredictOutput(preds))
        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
            mode=ModeKeys.PREDICT,
            predictions=preds,
            export_outputs=serving_outputs)

      # The tensor from `self.loss` is kept as-is for the train op; only its
      # mean is reported as the spec's scalar loss.
      training_loss = self.loss(
          logits=logits,
          labels=labels,
          features=features,
          mode=mode,
          regularization_losses=regularization_losses)
      mean_loss = tf.reduce_mean(training_loss)

      if mode == ModeKeys.EVAL:
        metric_objects = self.metrics(
            regularization_losses=regularization_losses)
        update_kwargs = {
            'eval_metrics': metric_objects,
            'features': features,
            'logits': logits,
            'labels': labels,
            'regularization_losses': regularization_losses,
        }
        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
            mode=ModeKeys.EVAL,
            predictions=preds,
            loss=mean_loss,
            eval_metrics=base_head.create_eval_metrics_tuple(
                self.update_metrics, update_kwargs))

      # Any mode other than PREDICT or EVAL falls through to TRAIN.
      train_op = base_head.create_estimator_spec_train_op(
          head_name=self._name,
          optimizer=optimizer,
          train_op_fn=train_op_fn,
          update_ops=update_ops,
          trainable_variables=trainable_variables,
          regularized_training_loss=training_loss,
          loss_reduction=self._loss_reduction)

    # NOTE: the training summary is created after the head's name scope has
    # been exited.
    base_head.create_estimator_spec_summary(
        regularized_training_loss=mean_loss,
        regularization_losses=regularization_losses,
        summary_key_fn=self._summary_key)
    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
        mode=ModeKeys.TRAIN,
        predictions=preds,
        loss=mean_loss,
        train_op=train_op)