in tensorflow_privacy/privacy/estimators/multi_label_head.py [0:0]
def _create_tpu_estimator_spec(self,
                               features,
                               mode,
                               logits,
                               labels=None,
                               optimizer=None,
                               trainable_variables=None,
                               train_op_fn=None,
                               update_ops=None,
                               regularization_losses=None):
  """Returns a `model_fn._TPUEstimatorSpec` for `mode`.

  See superclass for the full description of the arguments.

  All ops are created under `tf.compat.v1.name_scope(self._name, 'head')`.
  The spec depends on `mode`:

  * `ModeKeys.PREDICT`: returns the predictions plus three export outputs:
    default and classify signatures built from the predicted probabilities,
    and a predict signature wrapping all predictions. No loss is computed.
  * `ModeKeys.EVAL`: computes the loss via `self.loss` and averages it with
    `tf.reduce_mean`. Eval metrics are created through
    `base_head.create_eval_metrics_tuple(self.update_metrics, ...)`.
  * Any other mode is handled as `ModeKeys.TRAIN`. The train op is built from
    the loss returned by `self.loss` *before* averaging. Loss summaries are
    written from the averaged value.

  Returns:
    A `model_fn._TPUEstimatorSpec` for the handled mode.
  """
  with tf.compat.v1.name_scope(self._name, 'head'):
    predictions = self.predictions(logits)
    spec_cls = model_fn._TPUEstimatorSpec  # pylint: disable=protected-access

    # PREDICT: only serving signatures are needed; labels and loss are unused.
    if mode == ModeKeys.PREDICT:
      scores = predictions[prediction_keys.PredictionKeys.PROBABILITIES]
      classify_output = base_head.classification_output(
          scores=scores,
          n_classes=self._n_classes,
          label_vocabulary=self._label_vocabulary)
      serving_outputs = {
          base_head.DEFAULT_SERVING_KEY: classify_output,
          base_head.CLASSIFY_SERVING_KEY: classify_output,
          base_head.PREDICT_SERVING_KEY: export_output.PredictOutput(
              predictions),
      }
      return spec_cls(
          mode=ModeKeys.PREDICT,
          predictions=predictions,
          export_outputs=serving_outputs)

    # Both EVAL and TRAIN need the loss. `training_loss` is what `self.loss`
    # returns; `mean_loss` is its scalar average used for reporting.
    # NOTE(review): `training_loss` is presumably unreduced (per-example) so a
    # DP optimizer can clip per-example gradients -- confirm against
    # `self.loss` / `self._loss_reduction`.
    training_loss = self.loss(
        logits=logits,
        labels=labels,
        features=features,
        mode=mode,
        regularization_losses=regularization_losses)
    mean_loss = tf.reduce_mean(training_loss)

    if mode == ModeKeys.EVAL:
      metric_objects = self.metrics(
          regularization_losses=regularization_losses)
      # Keyword arguments forwarded to `self.update_metrics` when the eval
      # metrics tuple is evaluated.
      metric_fn_kwargs = {
          'eval_metrics': metric_objects,
          'features': features,
          'logits': logits,
          'labels': labels,
          'regularization_losses': regularization_losses
      }
      return spec_cls(
          mode=ModeKeys.EVAL,
          predictions=predictions,
          loss=mean_loss,
          eval_metrics=base_head.create_eval_metrics_tuple(
              self.update_metrics, metric_fn_kwargs))

    # Every remaining mode is treated as TRAIN. The train op consumes the
    # loss before averaging; summaries report the averaged value.
    train_op = base_head.create_estimator_spec_train_op(
        head_name=self._name,
        optimizer=optimizer,
        train_op_fn=train_op_fn,
        update_ops=update_ops,
        trainable_variables=trainable_variables,
        regularized_training_loss=training_loss,
        loss_reduction=self._loss_reduction)
    base_head.create_estimator_spec_summary(
        regularized_training_loss=mean_loss,
        regularization_losses=regularization_losses,
        summary_key_fn=self._summary_key)
    return spec_cls(
        mode=ModeKeys.TRAIN,
        predictions=predictions,
        loss=mean_loss,
        train_op=train_op)