in neural_structured_learning/keras/adversarial_regularization.py [0:0]
def _prepare_loss_fns(loss, output_names):
  """Converts `loss` to a list of per-output loss functions or objects.

  Args:
    loss: One of
      - a mapping from output name to loss identifier; it must contain an
        entry for every name in `output_names`;
      - a string naming a single loss shared by all outputs;
      - a sequence of loss identifiers, one per output, in the same order as
        `output_names`;
      - any other value (e.g. a loss function or loss object), shared by all
        outputs.
      Every identifier is resolved with `tf.keras.losses.get`.
    output_names: Sequence of the model's output names.

  Returns:
    A list with one resolved loss function or object per entry of
    `output_names`, in the same order.

  Raises:
    ValueError: If `loss` is a mapping that lacks an entry for some output
      name, or a sequence whose length differs from `len(output_names)`.
  """
  # `collections.Mapping` / `collections.Sequence` were removed in Python 3.10;
  # `six.moves.collections_abc` resolves to the ABC module on both Py2 and Py3.
  collections_abc = six.moves.collections_abc

  # Losses for multiple outputs indexed by name.
  if isinstance(loss, collections_abc.Mapping):
    for name in output_names:
      if name not in loss:
        raise ValueError(
            'Loss for {} not found in `loss` dictionary.'.format(name))
    return [tf.keras.losses.get(loss[name]) for name in output_names]

  # A string names one loss shared by all outputs. Strings are also sequences,
  # so this check must come before the positional-sequence branch below.
  if isinstance(loss, six.string_types):
    return [tf.keras.losses.get(loss) for _ in output_names]

  # Losses for multiple outputs indexed by position.
  if isinstance(loss, collections_abc.Sequence):
    if len(loss) != len(output_names):
      raise ValueError('`loss` should have the same number of elements as '
                       'model output')
    # Build a real list (not a lazy `map` iterator) so this branch returns the
    # same type as the others and the result can be indexed or re-iterated.
    return [tf.keras.losses.get(loss_fn) for loss_fn in loss]

  # Anything else (a callable or loss object) is shared by all outputs.
  return [tf.keras.losses.get(loss) for _ in output_names]