def _prepare_loss_fns()

in neural_structured_learning/keras/adversarial_regularization.py [0:0]


def _prepare_loss_fns(loss, output_names):
  """Converts `loss` to a list of per-output loss functions or objects."""
  # losses for multiple outputs indexed by name
  if isinstance(loss, collections.Mapping):
    for name in output_names:
      if name not in loss:
        raise ValueError(
            'Loss for {} not found in `loss` dictionary.'.format(name))
    return [tf.keras.losses.get(loss[name]) for name in output_names]

  # loss for single output, or shared loss fn for multiple outputs
  if isinstance(loss, six.string_types):
    return [tf.keras.losses.get(loss) for _ in output_names]

  # losses for multiple outputs indexed by position
  if isinstance(loss, collections.Sequence):
    if len(loss) != len(output_names):
      raise ValueError('`loss` should have the same number of elements as '
                       'model output')
    return six.moves.map(tf.keras.losses.get, loss)

  # loss for single output, or shared loss fn for multiple outputs
  return [tf.keras.losses.get(loss) for _ in output_names]