def _create_tpu_estimator_spec()

in tensorflow_privacy/privacy/estimators/v1/head.py [0:0]


  def _create_tpu_estimator_spec(self,
                                 features,
                                 mode,
                                 logits,
                                 labels=None,
                                 optimizer=None,
                                 train_op_fn=None,
                                 regularization_losses=None):
    """Returns an `EstimatorSpec`.

    Args:
      features: Input `dict` of `Tensor` or `SparseTensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` with shape `[D0, D1, ... DN, 1]`. For many
        applications, the shape is `[batch_size, 1]`.
      labels: Integer or string labels `Tensor` with shape matching `logits`,
        namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is a
        required argument when `mode` equals `TRAIN` or `EVAL`.
      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
        updates variables and increments `global_step`.
      train_op_fn: Function that takes a scalar loss `Tensor` and returns
        `train_op`. Used if `optimizer` is `None`.
      regularization_losses: A list of additional scalar losses to be added to
        the training loss, such as regularization losses. These losses are
        usually expressed as a batch average, so for best results users need to
        set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to avoid
        scaling errors.

    Returns:
      `EstimatorSpec`.

    Raises:
      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
        mode, or if both are set.
    """
    # Predict.
    with tf.compat.v1.name_scope(self._name, 'head'):
      with tf.compat.v1.name_scope(None, 'predictions', (logits,)):
        pred_keys = prediction_keys.PredictionKeys
        logits = _check_logits_final_dim(logits, self.logits_dimension)
        logistic = tf.math.sigmoid(logits, name=pred_keys.LOGISTIC)
        # Pad a zero column so the single logit can be treated as two-class
        # logits: softmax([0, z]) == [1 - sigmoid(z), sigmoid(z)].
        two_class_logits = tf.concat((tf.compat.v1.zeros_like(logits), logits),
                                     axis=-1,
                                     name='two_class_logits')
        probabilities = tf.compat.v1.nn.softmax(
            two_class_logits, name=pred_keys.PROBABILITIES)
        class_ids = tf.compat.v1.math.argmax(
            two_class_logits, axis=-1, name=pred_keys.CLASS_IDS)
        class_ids = tf.compat.v1.expand_dims(class_ids, axis=-1)
        all_class_ids = _all_class_ids(logits, n_classes=2)
        all_classes = _all_classes(
            logits, n_classes=2, label_vocabulary=self._label_vocabulary)

        if self._label_vocabulary:
          table = lookup_ops.index_to_string_table_from_tensor(
              vocabulary_list=self._label_vocabulary,
              name='class_string_lookup')
          classes = table.lookup(class_ids)
        else:
          classes = tf.strings.as_string(class_ids, name='str_classes')
        predictions = {
            pred_keys.LOGITS: logits,
            pred_keys.LOGISTIC: logistic,
            pred_keys.PROBABILITIES: probabilities,
            pred_keys.CLASS_IDS: class_ids,
            pred_keys.CLASSES: classes,
            pred_keys.ALL_CLASS_IDS: all_class_ids,
            pred_keys.ALL_CLASSES: all_classes,
        }
      if mode == ModeKeys.PREDICT:
        classifier_output = _classification_output(
            scores=probabilities,
            n_classes=2,
            label_vocabulary=self._label_vocabulary)
        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
            mode=ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                _DEFAULT_SERVING_KEY: classifier_output,
                _CLASSIFY_SERVING_KEY: classifier_output,
                _REGRESS_SERVING_KEY: export_output.RegressionOutput(
                    value=logistic),
                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
            })

      (training_loss, unreduced_loss, weights, processed_labels) = (
          self.create_loss(
              features=features, mode=mode, logits=logits, labels=labels))
      if regularization_losses:
        regularization_loss = tf.math.add_n(regularization_losses)
        regularized_training_loss = tf.math.add_n(
            [training_loss, regularization_loss])
      else:
        regularization_loss = None
        regularized_training_loss = training_loss

      # With `Reduction.NONE` the training loss stays a per-example vector
      # (as DP optimizers need for per-microbatch gradient clipping), so
      # reduce it to a scalar here only for the returned spec and summaries.
      if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE:
        scalar_loss = tf.reduce_mean(regularized_training_loss)
      else:
        scalar_loss = regularized_training_loss
      # Eval.
      if mode == ModeKeys.EVAL:
        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
            mode=ModeKeys.EVAL,
            predictions=predictions,
            loss=scalar_loss,
            eval_metrics=_create_eval_metrics_tuple(
                self._eval_metric_ops, {
                    'labels': processed_labels,
                    'logits': logits,
                    'logistic': logistic,
                    'class_ids': class_ids,
                    'weights': weights,
                    'unreduced_loss': unreduced_loss,
                    'regularization_loss': regularization_loss
                }))

      # Train.
      if optimizer is not None:
        if train_op_fn is not None:
          raise ValueError('train_op_fn and optimizer cannot both be set.')
        train_op = optimizer.minimize(
            regularized_training_loss,
            global_step=tf.compat.v1.train.get_global_step())
      elif train_op_fn is not None:
        train_op = train_op_fn(regularized_training_loss)
      else:
        raise ValueError('train_op_fn and optimizer cannot both be None.')
      train_op = _append_update_ops(train_op)
      # Only summarize mean_loss for SUM reduction to preserve backwards
      # compatibility. Otherwise skip it to avoid unnecessary computation.
      if self._loss_reduction == tf.compat.v1.losses.Reduction.SUM:
        example_weight_sum = tf.math.reduce_sum(
            weights * tf.compat.v1.ones_like(unreduced_loss))
        mean_loss = training_loss / example_weight_sum
      else:
        mean_loss = None
    with tf.compat.v1.name_scope(''):
      keys = metric_keys.MetricKeys
      tf.compat.v1.summary.scalar(
          _summary_key(self._name, keys.LOSS), scalar_loss)
      if mean_loss is not None:
        tf.compat.v1.summary.scalar(
            _summary_key(self._name, keys.LOSS_MEAN), mean_loss)
      if regularization_loss is not None:
        tf.compat.v1.summary.scalar(
            _summary_key(self._name, keys.LOSS_REGULARIZATION),
            regularization_loss)
    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
        mode=ModeKeys.TRAIN,
        predictions=predictions,
        loss=scalar_loss,
        train_op=train_op)
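
Usage note: this method is normally reached indirectly, through the head's public `create_estimator_spec` inside a custom `model_fn`. The sketch below is a minimal, illustrative wiring of the TRAIN-path contract described in the docstring with a tensorflow_privacy optimizer. `make_model_fn`, the `features['x']` input, the layer sizes, and the DP hyperparameter values are assumptions for illustration only; `head` stands for an instance of the head class defined in this module, built with `loss_reduction=tf.losses.Reduction.NONE`.

import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer


def make_model_fn(head):
  """Builds an illustrative model_fn around `head.create_estimator_spec`.

  `head` is assumed to be an instance of the head class defined in this
  module, constructed with `loss_reduction=tf.losses.Reduction.NONE` so the
  training loss handed to the optimizer stays a per-example vector.
  """

  def model_fn(features, labels, mode):
    # Single-logit binary classifier: logits have shape `[batch_size, 1]`,
    # as the docstring above requires.
    hidden = tf.layers.dense(features['x'], 16, activation=tf.nn.relu)
    logits = tf.layers.dense(hidden, 1)

    # A differentially private optimizer; it clips and noises gradients per
    # microbatch, which is why the head avoids reducing the training loss.
    # The hyperparameter values here are placeholders, not recommendations.
    optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
        l2_norm_clip=1.0,
        noise_multiplier=1.1,
        num_microbatches=1,
        learning_rate=0.15)

    # Passing `optimizer` (and leaving `train_op_fn=None`) selects the
    # `optimizer.minimize(...)` branch in `_create_tpu_estimator_spec`.
    return head.create_estimator_spec(
        features=features,
        mode=mode,
        logits=logits,
        labels=labels,
        optimizer=optimizer)

  return model_fn

Passing `train_op_fn=lambda loss: optimizer.minimize(loss, global_step=tf.compat.v1.train.get_global_step())` instead of `optimizer=optimizer` exercises the alternative branch; as documented above, supplying both (or neither) raises `ValueError`.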