in tensorflow_model_remediation/min_diff/losses/kernels/kernel_utils.py [0:0]
def _get_kernel(
    kernel: Union[base_kernel.MinDiffKernel, Text, None],
    kernel_var_name: Text = 'kernel'
) -> Union[base_kernel.MinDiffKernel, None]:
  """Returns the `losses.MinDiffKernel` instance corresponding to `kernel`.

  Resolution rules, applied in order:

  * `None` is passed through and `None` is returned.
  * An instance of `losses.MinDiffKernel` is returned unchanged.
  * A string is lower-cased and looked up in `_STRING_TO_KERNEL_DICT`; the
    matching entry is called with no arguments and its result is returned.

  Args:
    kernel: Kernel to resolve. Can be `None`, a string or an instance of
      `losses.MinDiffKernel`.
    kernel_var_name: Name of the kernel variable. This is only used for error
      messaging.

  Returns:
    A `MinDiffKernel` instance, or `None` if `kernel` is `None`.

  Raises:
    ValueError: If `kernel` is a string that does not (case-insensitively)
      match one of the keys of `_STRING_TO_KERNEL_DICT`.
    TypeError: If `kernel` is not `None`, a string or a `MinDiffKernel`.
  """
  if kernel is None:
    return None

  if isinstance(kernel, base_kernel.MinDiffKernel):
    return kernel

  if isinstance(kernel, six.string_types):
    lower_case_kernel = kernel.lower()
    if lower_case_kernel in _STRING_TO_KERNEL_DICT:
      return _STRING_TO_KERNEL_DICT[lower_case_kernel]()
    # Format a sorted list rather than the raw keys view so the message shows
    # the supported names in a stable order, without the `dict_keys(...)`
    # wrapper that a dict view's repr produces.
    supported_values = sorted(_STRING_TO_KERNEL_DICT)
    raise ValueError('If {} is a string, it must be a (case-insensitive) '
                     'match for one of the following supported values: {}. '
                     'given: {}'.format(kernel_var_name, supported_values,
                                        kernel))

  raise TypeError('{} must be either of type MinDiffKernel or string, given: '
                  '{} (type: {})'.format(kernel_var_name, kernel, type(kernel)))