in maddpg/common/tf_util.py [0:0]
def __call__(self, *args, **kwargs):
    """Feed the given values into the graph, run it, and return the fetched outputs.

    Positional ``args`` are bound, in order, to the leading entries of
    ``self.inputs``.  Each remaining input is looked up in ``kwargs`` by its
    bare name: the text of ``input.name`` before the first ``':'`` and after
    the last ``'/'`` (e.g. ``"scope/obs:0"`` -> ``"obs"``).  A remaining input
    that is not supplied as a kwarg must have an entry in ``self.givens``.
    Values from ``self.givens`` are used for every placeholder that was not
    fed explicitly.

    Returns:
        The result of ``get_session().run(self.outputs_update, ...)`` with its
        last element dropped.  NOTE(review): the dropped element is presumably
        an update op appended when the function was built -- confirm against
        the constructor.

    Raises:
        AssertionError: more positional args than inputs, two remaining
            inputs sharing a bare name when kwargs are used, a missing
            argument with no given, or leftover unexpected kwargs.  These
            checks are ``assert`` statements and are skipped under
            ``python -O``.
        RuntimeError: ``self.check_nan`` is truthy and some result has a NaN.
    """
    n_positional = len(args)
    assert n_positional <= len(self.inputs), "Too many arguments provided"
    feed_dict = {}

    # Positional values bind to inputs in declaration order.
    for placeholder, value in zip(self.inputs, args):
        self._feed_input(feed_dict, placeholder, value)

    # Every input not covered positionally is matched against kwargs by bare name.
    claimed_names = set()
    for placeholder in self.inputs[n_positional:]:
        bare_name = placeholder.name.partition(':')[0].rpartition('/')[2]
        # A second input with an already-claimed bare name makes kwargs ambiguous.
        assert bare_name not in claimed_names, \
            "this function has two arguments with the same name \"{}\", so kwargs cannot be used.".format(bare_name)
        if bare_name not in kwargs:
            # Not passed by keyword: only acceptable when a default is given.
            assert placeholder in self.givens, "Missing argument " + bare_name
            continue
        claimed_names.add(bare_name)
        self._feed_input(feed_dict, placeholder, kwargs.pop(bare_name))

    # Anything still left in kwargs did not match any input.
    assert not kwargs, "Function got extra arguments " + str(list(kwargs.keys()))

    # Givens fill only the placeholders that were not fed explicitly.
    for placeholder in self.givens:
        feed_dict.setdefault(placeholder, self.givens[placeholder])

    fetched = get_session().run(self.outputs_update, feed_dict=feed_dict)
    results = fetched[:-1]
    if self.check_nan and any(np.isnan(r).any() for r in results):
        raise RuntimeError("Nan detected")
    return results