def recursive_apply_dict()

in runtool/runtool/recurse_config.py [0:0]


def recursive_apply_dict(node: dict, fn: Callable) -> Any:
    """
    Recursively apply `fn` to a dict node, expanding `Versions` children.

    Every value of `node` is first processed with `recursive_apply`.
    What happens next depends on the processed children:

    - If none of them is a `runtool.datatypes.Versions` object, `fn` is
      applied once to the dict of processed children and its result is
      returned.
    - Otherwise the cartesian product of the versioned children is
      calculated. Each combination is merged with the unversioned children
      into one version of the node, `fn` is applied to that version, and
      the results are returned in a new `runtool.datatypes.Versions` object.
    - If every one of those results is itself a `Versions` object, the
      returned `Versions` instead holds the cartesian product of them,
      one list per combination.

    Parameters
    ----------
    node
        The dictionary whose values should be processed.
    fn
        The callable applied to the processed node, or to each version
        of it.

    Returns
    -------
    Any
        The result of `fn` when no child is a `Versions` object, otherwise
        a `Versions` object as described above.
    """
    versioned_children = []
    plain_children = {}
    for key, value in node.items():
        child = recursive_apply(value, fn)
        # If the child is a Versions object, map the key to all its versions,
        # child = Versions([1,2]),
        # key = 'a'
        # ->
        # (('a', 1), ('a', 2))
        if isinstance(child, Versions):
            versioned_children.append(itertools.product([key], child))
        else:
            plain_children[key] = child

    if not versioned_children:
        return fn(plain_children)

    # example:
    # versioned_children = [(('a', 1), ('a', 2)), (('b', 1), ('b', 2))]
    # plain_children = {"c": 3}
    # results in fn being applied to each of:
    # [
    #   {'a':1, 'b':1, 'c':3},
    #   {'a':1, 'b':2, 'c':3},
    #   {'a':2, 'b':1, 'c':3},
    #   {'a':2, 'b':2, 'c':3},
    # ]
    # The merge uses dict-display unpacking rather than
    # `dict(mapping, **kwargs)` because the latter only accepts string keys.
    node_versions = [
        fn({**dict(combination), **plain_children})
        for combination in itertools.product(*versioned_children)
    ]

    # if the current node generated Versions objects, these
    # need to be flattened as well. For example:
    # node_versions = [Versions([1,2]), Versions([3,4])]
    # results in
    # Versions([[1,3], [1,4], [2,3], [2,4]])
    # Each combination is converted to a list individually; passing all
    # combinations to a single `list(...)` call as separate arguments would
    # raise a TypeError as soon as there is more than one combination.
    if node_versions and all(
        isinstance(version, Versions) for version in node_versions
    ):
        return Versions(
            [list(combo) for combo in itertools.product(*node_versions)]
        )
    return Versions(node_versions)