in tensorflow_model_remediation/min_diff/keras/models/min_diff_model.py [0:0]
def test_step(self, data, *args, **kwargs):
    """Run a single evaluation step.

    Behaves exactly like `tf.keras.Model.test_step`, except that the
    'min_diff_loss' metric(s) are dropped from the returned results whenever
    the inputs carry no `min_diff_data`.

    Args:
      data: The evaluation batch, forwarded unchanged to the parent
        `test_step` and also unpacked here to inspect its inputs.
      *args: Extra positional arguments forwarded to the parent `test_step`.
      **kwargs: Extra keyword arguments forwarded to the parent `test_step`.

    Returns:
      The metric results produced by the parent `test_step`, without the
      entries named after the min_diff_loss metric(s) when
      `self.unpack_min_diff_data` returns `None` for the inputs.
    """
    results = super(MinDiffModel, self).test_step(data, *args, **kwargs)

    # Only the inputs matter for deciding whether min_diff_data is present.
    inputs, _, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
    if self.unpack_min_diff_data(inputs) is not None:
        return results

    # No min_diff_data: strip every min_diff_loss metric that was reported.
    for loss_metric in tf.nest.flatten(self._min_diff_loss_metric):
        results.pop(loss_metric.name, None)
    return results