in tensorflow_model_remediation/min_diff/losses/loss_utils.py [0:0]
def _get_loss(
    loss: Union[base_loss.MinDiffLoss, Text, None],
    loss_var_name: Text = 'loss') -> Union[base_loss.MinDiffLoss, None]:
  """Resolves `loss` into a `losses.MinDiffLoss` instance (or `None`).

  Resolution rules, applied in order:

  *   `None` is passed through unchanged and `None` is returned.
  *   A `losses.MinDiffLoss` instance is returned as is (same object).
  *   A string is lower-cased and looked up in `_STRING_TO_LOSS_DICT`; the
      matching entry is called with no arguments and its result is returned.

  Args:
    loss: The loss to resolve. Can be `None`, a string or an instance of
      `losses.MinDiffLoss`.
    loss_var_name: Name of the loss variable. This is only used for error
      messaging.

  Returns:
    A `MinDiffLoss` instance, or `None` if `loss` is `None`.

  Raises:
    ValueError: If `loss` is a string that does not match (case-insensitively)
      any key of `_STRING_TO_LOSS_DICT`.
    TypeError: If `loss` is neither `None`, an instance of
      `losses.MinDiffLoss`, nor a string.
  """
  # `None` is an accepted value and is handed back untouched.
  if loss is None:
    return None

  # Already a loss instance: nothing to build.
  if isinstance(loss, base_loss.MinDiffLoss):
    return loss

  # Anything that is not a string at this point is an unsupported type.
  if not isinstance(loss, six.string_types):
    raise TypeError('{} must be either of type MinDiffLoss or string, given: '
                    '{} (type: {})'.format(loss_var_name, loss, type(loss)))

  # String names are matched case-insensitively against the supported names.
  lower_case_loss = loss.lower()
  if lower_case_loss in _STRING_TO_LOSS_DICT:
    return _STRING_TO_LOSS_DICT[lower_case_loss]()

  raise ValueError('If {} is a string, it must be a (case-insensitive) '
                   'match for one of the following supported values: {}. '
                   'given: {}'.format(loss_var_name,
                                      _STRING_TO_LOSS_DICT.keys(), loss))