def _get_kernel()

in tensorflow_model_remediation/min_diff/losses/kernels/kernel_utils.py [0:0]


def _get_kernel(kernel: Union[base_kernel.MinDiffKernel, Text, None],
                kernel_var_name: Text = 'kernel'
               ) -> Union[base_kernel.MinDiffKernel, None]:
  """Returns a `losses.MinDiffKernel` instance corresponding to `kernel`.

  If `kernel` is an instance of `losses.MinDiffKernel` then it is returned
  directly. If `kernel` is a string it must be an accepted kernel name
  (matched case-insensitively); the matching registry entry is called with no
  arguments and its result is returned. A value of `None` is also accepted and
  simply returns `None`.

  Args:
    kernel: kernel instance. Can be `None`, a string or an instance of
      `losses.MinDiffKernel`.
    kernel_var_name: Name of the kernel variable. This is only used for error
      messaging.

  Returns:
    A `losses.MinDiffKernel` instance, or `None` if `kernel` is `None`.

  Raises:
    ValueError: If `kernel` is a string that does not match any supported
      kernel name.
    TypeError: If `kernel` is not `None`, a string, or a
      `losses.MinDiffKernel` instance.
  """
  if kernel is None:
    return None
  if isinstance(kernel, base_kernel.MinDiffKernel):
    return kernel
  if isinstance(kernel, six.string_types):
    lower_case_kernel = kernel.lower()
    if lower_case_kernel in _STRING_TO_KERNEL_DICT:
      # Registry values are presumably kernel classes (TODO confirm); calling
      # with no arguments builds a kernel with its default configuration.
      return _STRING_TO_KERNEL_DICT[lower_case_kernel]()
    # Render supported names as a plain list so the message reads
    # "['name', ...]" instead of leaking a "dict_keys([...])" repr.
    raise ValueError('If {} is a string, it must be a (case-insensitive) '
                     'match for one of the following supported values: {}. '
                     'given: {}'.format(kernel_var_name,
                                        list(_STRING_TO_KERNEL_DICT), kernel))
  raise TypeError('{} must be either of type MinDiffKernel or string, given: '
                  '{} (type: {})'.format(kernel_var_name, kernel, type(kernel)))