in tensorflow_gan/python/losses/tuple_losses.py [0:0]
def args_to_gan_model(loss_fn):
  """Converts a loss taking individual args to one taking a GANModel namedtuple.

  The new function has the same name as the original one.

  Each parameter of `loss_fn` is filled from, in order of precedence:
    * the namedtuple field of the same name (for parameters with a default,
      only when that field is not None),
    * a keyword argument passed to the new function,
    * the parameter's own default value.
  Keyword arguments that do not correspond to a named parameter are forwarded
  unchanged, which allows `loss_fn` to accept `**kwargs`.

  Args:
    loss_fn: A python function taking individual loss arguments (by keyword)
      and returning a loss Tensor. The shape of the loss depends on
      `reduction`.

  Returns:
    A new function that takes a GANModel namedtuple (plus any extra keyword
    args) and returns the same loss.
  """
  # Match arguments in `loss_fn` to elements of the namedtuple. Variadic
  # parameters (`*args`, `**kwargs`) can never be filled from a tuple field,
  # so they are excluded; extra keyword args are simply passed through.
  # NOTE: every argument is passed by keyword, so positional-only parameters
  # are not supported.
  signature_params = {
      name: param
      for name, param in inspect.signature(loss_fn).parameters.items()
      if param.kind not in (inspect.Parameter.VAR_POSITIONAL,
                            inspect.Parameter.VAR_KEYWORD)
  }
  required_args = set()
  default_args_dict = {}
  for name, param in signature_params.items():
    # Compare against the sentinel by identity: `==` may raise or return a
    # non-bool for defaults that overload equality (e.g. arrays, tensors).
    if param.default is inspect.Parameter.empty:
      required_args.add(name)
    else:
      default_args_dict[name] = param.default

  def _asdict(namedtuple):
    """Returns a namedtuple as a dictionary.

    This is required because `_asdict()` in Python 3.x.x is broken in classes
    that inherit from `collections.namedtuple`. See
    https://bugs.python.org/issue24931 for more details.

    Args:
      namedtuple: An object that inherits from `collections.namedtuple`.

    Returns:
      A dictionary version of the tuple.
    """
    return {k: getattr(namedtuple, k) for k in namedtuple._fields}

  def new_loss_fn(gan_model, **kwargs):  # pylint:disable=missing-docstring
    # `kwargs` is this call's own dict, so filling it in below does not
    # mutate anything owned by the caller.
    gan_model_dict = _asdict(gan_model)

    # Make sure non-tuple required args are supplied.
    args_from_tuple = set(signature_params).intersection(gan_model._fields)
    for arg in required_args - args_from_tuple:
      if arg not in kwargs:
        raise ValueError('`%s` must be supplied to %s loss function.' % (
            arg, loss_fn.__name__))

    # Make sure tuple args aren't also supplied as keyword args.
    ambiguous_args = set(gan_model._fields).intersection(kwargs)
    if ambiguous_args:
      raise ValueError(
          'The following args are present in both the tuple and keyword args '
          'for %s: %s' % (loss_fn.__name__, ambiguous_args))

    # Required args that come from the tuple. The ambiguity check above
    # guarantees none of them is also present in `kwargs`.
    for arg in required_args & args_from_tuple:
      kwargs[arg] = gan_model_dict[arg]

    # Args with defaults: a non-None tuple value wins, then a non-None keyword
    # value, and otherwise the function's own default is used.
    for arg, default in default_args_dict.items():
      val_from_tuple = gan_model_dict.get(arg)
      val_from_kwargs = kwargs.get(arg)
      if val_from_tuple is not None:
        kwargs[arg] = val_from_tuple
      elif val_from_kwargs is not None:
        kwargs[arg] = val_from_kwargs
      else:
        kwargs[arg] = default
    return loss_fn(**kwargs)

  # `__doc__` (not `__docstring__`) is the attribute Python and `help()` read.
  new_loss_fn.__doc__ = 'The gan_model version of %s.' % loss_fn.__name__
  new_loss_fn.__name__ = loss_fn.__name__
  new_loss_fn.__module__ = loss_fn.__module__
  return new_loss_fn