def multimap()

in gym3/types.py [0:0]


def multimap(f: Callable, *xs: Any) -> Any:
    """
    Apply f at each leaf of the list of trees

    A tree is:
        * a (possibly nested) dict
        * a (possibly nested) DictType
        * any other object (a leaf value)

    `{"a": 1}`, `{"a": {"b": 2}}`, and `3` are all valid trees, where the leaf values
    are the integers

    The trees are walked in lockstep, with the first tree deciding the structure.
    Wherever the first tree has a dict/DictType node, every other tree must also
    have a dict/DictType node with exactly the same keys at that position.
    Wherever the first tree has a leaf, f is called on the corresponding values of
    all trees. Those values are passed through unchecked, whatever their type.
    Internal nodes of the result are always plain dicts (even where the inputs
    were DictType), with keys inserted in sorted order.

    :param f: function to call at each leaf, must take len(xs) arguments
    :param xs: a list of trees, all with the same structure; at least one tree is
        required (with none, indexing xs[0] raises IndexError)

    :returns: A tree of the same structure, where each leaf contains f's return value.

    :raises AssertionError: if the trees do not share the same structure (these
        checks use `assert`, so they are skipped under `python -O`)
    """
    first = xs[0]
    if not isinstance(first, (dict, DictType)):
        # The first tree is at a leaf here, so the remaining trees' values are
        # handed to f as-is.
        return f(*xs)

    # Sort the reference keys once per level instead of once per tree.
    keys = sorted(first.keys())
    for i, x in enumerate(xs):
        assert isinstance(x, (dict, DictType)), (
            "tree {} is a {} where tree 0 has a dict".format(i, type(x).__name__)
        )
        x_keys = sorted(x.keys())
        assert x_keys == keys, (
            "tree {} has keys {}, expected {}".format(i, x_keys, keys)
        )
    return {k: multimap(f, *(x[k] for x in xs)) for k in keys}