in tensorflow_model_remediation/min_diff/losses/base_loss.py [0:0]
def __init__(self,
             membership_transform=None,
             predictions_transform=None,
             membership_kernel=None,
             predictions_kernel=None,
             name: Optional[Text] = None):
    """Initialize `MinDiffLoss` instance.

    Args:
      membership_transform: Optional transform, checked with
        `_validate_transform` and then stored as `self.membership_transform`.
      predictions_transform: Optional transform, checked with
        `_validate_transform` and then stored as `self.predictions_transform`.
      membership_kernel: Kernel specification handed to
        `kernel_utils._get_kernel`; the result is stored as
        `self.membership_kernel`.
      predictions_kernel: Kernel specification handed to
        `kernel_utils._get_kernel`; the result is stored as
        `self.predictions_kernel`.
      name: Optional name for the instance. When omitted (or falsy), the
        snake_case form of the class name is used instead.

    Raises:
      ValueError: If a `*_transform` parameter is passed in but is not callable.
      ValueError: If a `*_kernel` parameter has an unrecognized type or value.
    """
    # The parent loss is always configured with `Reduction.NONE`.
    super(MinDiffLoss, self).__init__(
        name=name, reduction=tf.keras.losses.Reduction.NONE)

    # Fall back to a name derived from the concrete class only when the
    # caller did not supply a usable one.
    if name:
        self.name = name
    else:
        self.name = _to_snake_case(self.__class__.__name__)

    # Each transform is validated immediately before it is stored, so an
    # invalid one is never assigned to the instance.
    _validate_transform(membership_transform, 'membership_transform')
    self.membership_transform = membership_transform

    _validate_transform(predictions_transform, 'predictions_transform')
    self.predictions_transform = predictions_transform

    # Kernels are resolved through `kernel_utils` before being stored.
    resolved_membership_kernel = kernel_utils._get_kernel(
        membership_kernel, 'membership_kernel')
    self.membership_kernel = resolved_membership_kernel

    resolved_predictions_kernel = kernel_utils._get_kernel(
        predictions_kernel, 'predictions_kernel')
    self.predictions_kernel = resolved_predictions_kernel