in gym3/types.py [0:0]
def multimap(f: Callable, *xs: Any) -> Any:
    """
    Apply f at each leaf of the list of trees

    A tree is:
        * a (possibly nested) dict
        * a (possibly nested) DictType
        * any other object (a leaf value)

    `{"a": 1}`, `{"a": {"b": 2}}`, and `3` are all valid trees, where the leaf values
    are the integers

    All trees must share the structure of the first one: wherever the first tree
    has a dict/DictType node, every other tree must also have a dict/DictType node
    with exactly the same set of keys. Wherever the first tree has a leaf, the
    values at that position in all trees are passed to f unchanged, whatever
    their type.

    The returned tree always uses plain `dict` nodes (even where the inputs were
    DictType), with keys inserted in sorted order; keys at each node must
    therefore be mutually orderable.

    :param f: function to call at each leaf, must take len(xs) arguments
    :param xs: a list of trees, all with the same structure
    :returns: A tree of the same structure, where each leaf contains f's return value.
    :raises AssertionError: if the trees do not share the same structure
    :raises IndexError: if no trees are passed
    """
    first = xs[0]
    node_types = (dict, DictType)
    if not isinstance(first, node_types):
        # Leaf position (decided by the first tree): hand all values to f.
        return f(*xs)

    # Explicit raises (rather than `assert`) so the structure check still runs
    # under `python -O`; AssertionError is kept so callers see the same
    # exception type as before. Type check runs over all trees before any key
    # comparison, matching the original order of checks.
    for i, x in enumerate(xs):
        if not isinstance(x, node_types):
            raise AssertionError(
                f"multimap: tree {i} has a leaf of type {type(x).__name__} "
                "where tree 0 has a dict node"
            )

    keys = sorted(first.keys())
    for i, x in enumerate(xs):
        x_keys = sorted(x.keys())
        if x_keys != keys:
            raise AssertionError(
                f"multimap: tree {i} has keys {x_keys}, expected {keys}"
            )

    return {k: multimap(f, *(x[k] for x in xs)) for k in keys}