def test_step()

in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py [0:0]


  def test_step(self, data, *args, **kwargs):
    """Runs one evaluation step.

    Delegates entirely to `tf.keras.Model.test_step`. The only difference is
    what happens when the batch carries no `min_diff_data`: the
    'min_diff_loss' metric entry (or entries) are stripped from the returned
    mapping before it is handed back.

    Args:
      data: The evaluation batch, in any form accepted by
        `tf.keras.utils.unpack_x_y_sample_weight`.
      *args: Forwarded unchanged to the parent `test_step`.
      **kwargs: Forwarded unchanged to the parent `test_step`.

    Returns:
      The metrics mapping produced by the parent `test_step`, keyed by metric
      name, with the MinDiff loss metric(s) removed if `min_diff_data` is
      absent.
    """
    results = super(MinDiffModel, self).test_step(data, *args, **kwargs)
    inputs, _, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
    if self.unpack_min_diff_data(inputs) is not None:
      # MinDiff data is present, so keep every metric the parent reported.
      return results
    # No min_diff_data for this batch: drop the min_diff_loss metric(s).
    # `_min_diff_loss_metric` may be a single metric or a structure of them,
    # so flatten before iterating.
    for loss_metric in tf.nest.flatten(self._min_diff_loss_metric):
      results.pop(loss_metric.name, None)
    return results