def minimizer()

in functorch/_src/fx_minifier.py [0:0]


def minimizer(fail_f: fx.GraphModule, inps, module_fails):
    """
    Minimizes an FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.

    Repeatedly runs the following strategies until a full round makes no progress:
    1. Truncate suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,
        tries replacing quarter of the graph, etc.
    3. Dead code elimination and removal of unused inputs, run after each of the above.

    >>> failing_function = fx.symbolic_trace(f)
    >>> minimizer(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.

    Args:
        fail_f: the GraphModule exhibiting the failure. ``ConcreteProp`` is run on it
            with ``inps`` before minimizing (presumably to record concrete node values
            that delta debugging later turns into inputs -- see ConcreteProp).
        inps: the inputs ``fail_f`` is invoked with.
        module_fails: callable ``(GraphModule, inputs) -> bool``; truthy when the
            failure still reproduces.

    Returns:
        ``(minimized GraphModule, inputs for it)``.

    Raises:
        RuntimeError: if ``fail_f`` does not fail on ``inps`` to begin with.
    """
    failing_graph = fail_f.graph
    cur_size = len(failing_graph.nodes)

    def graph_fails(graph, inps):
        # Wrap the candidate graph in a GraphModule (rooted at fail_f), lint it,
        # and ask the tester whether the failure still reproduces.
        mod = fx.GraphModule(fail_f, graph)
        mod.graph.lint()
        return module_fails(mod, inps)

    ConcreteProp(fail_f).propagate(*inps)
    if not graph_fails(failing_graph, inps):
        raise RuntimeError("Input graph did not fail the tester")
    print(f"Started off with {cur_size} nodes")

    def remove_suffix(cur_graph, cur_inps):
        """Try to cut the graph after some node and make that node the output.

        Candidate cut points are visited at power-of-two strides, coarse to fine.
        Returns ``((graph, inps), succeeded)``.
        """
        print("Strategy: Remove suffix")
        assert graph_fails(cur_graph, cur_inps)
        num_nodes = len(cur_graph.nodes)  # cur_graph is not mutated below
        gap = 2**math.floor(math.log2(num_nodes))
        # Cut points already tried at a coarser gap; don't re-test them.
        tested = set()
        while gap >= 1:
            new_graph = fx.Graph()
            env = {}
            for idx, node in enumerate(cur_graph.nodes):
                new_node = new_graph.node_copy(node, lambda x: env[x])
                if node.op not in ['placeholder', 'output']:
                    if idx % gap == 0 and idx not in tested:
                        # Temporarily terminate the partial copy right after this node.
                        output_node = new_graph.output((new_node,))
                        # Check the cheap size condition first so the tester is never
                        # invoked on a candidate that would be rejected anyway.
                        if len(new_graph.nodes) < num_nodes and graph_fails(new_graph, cur_inps):
                            print()
                            print(f"SUCCESS: Removed [{idx}:{num_nodes})")
                            return (new_graph, cur_inps), True
                        tested.add(idx)
                        # Undo the temporary output and keep copying.
                        new_graph.erase_node(output_node)
                env[node] = new_node
            gap //= 2
        print("FAIL: Could not remove suffix")
        return (cur_graph, cur_inps), False

    def remove_unused_inputs(cur_graph, cur_inps):
        """Erase placeholders with no users (in place) and drop their inputs.

        Returns ``((graph, inps), succeeded)``; succeeded only if at least one
        input was dropped and the graph still fails.
        """
        assert graph_fails(cur_graph, cur_inps)
        ph_nodes = _get_placeholders(cur_graph)
        if len(ph_nodes) != len(cur_inps):
            # Dump diagnostics before the assert below fires.
            print(cur_graph)
            print(len(cur_inps))
        assert len(ph_nodes) == len(cur_inps)

        new_inps = []
        for ph_node, inp in zip(ph_nodes, cur_inps):
            if len(ph_node.users) == 0:
                cur_graph.erase_node(ph_node)
            else:
                new_inps.append(inp)

        if len(new_inps) < len(cur_inps) and graph_fails(cur_graph, new_inps):
            print("Strategy: Remove unused inputs")
            print(f"SUCCESS: Went from {len(cur_inps)} inputs to {len(new_inps)} inputs")
            return (cur_graph, new_inps), True
        else:
            return (cur_graph, new_inps), False

    def eliminate_dead_code(cur_graph, cur_inps):
        """Run FX dead-code elimination (in place); succeed if it changed the graph and it still fails."""
        orig_size = len(cur_graph.nodes)
        if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps):
            print("Strategy: Eliminate dead code")
            print(f"SUCCESS: Went from {orig_size} nodes to {len(cur_graph.nodes)} nodes")
            return (cur_graph, cur_inps), True
        else:
            return (cur_graph, cur_inps), False

    def consolidate_placeholders(cur_graph):
        """Return a copy of cur_graph with all placeholders moved to the front.

        Relative order is preserved both among placeholders and among the rest.
        """
        new_graph = fx.Graph()
        env = {}
        for node in cur_graph.nodes:
            if node.op == 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node

        for node in cur_graph.nodes:
            if node.op != 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
        return new_graph

    def delta_debugging(cur_graph: fx.Graph, cur_inps):
        """Replace a window of nodes with inputs, halving the window size on failure.

        Returns ``((graph, inps), succeeded)``.
        """
        print("Strategy: Delta Debugging")
        assert graph_fails(cur_graph, cur_inps)
        starting_placeholders = len(_get_placeholders(cur_graph))
        num_nodes = len(cur_graph.nodes)
        gap = int(2**math.floor(math.log2(num_nodes)))
        while gap >= 1:
            for start_range in range(0, num_nodes, gap):
                is_removing = False
                new_graph = copy.deepcopy(cur_graph)
                new_inps = cur_inps[:]
                end_range = min(num_nodes, start_range + gap)
                for idx in range(start_range, end_range):
                    new_node = list(new_graph.nodes)[idx]
                    if new_node.op not in ['placeholder', 'output']:
                        is_removing = True
                        _convert_node_to_placeholder(new_node, new_inps)
                # Window held only placeholders/output: nothing to try.
                if not is_removing:
                    continue
                new_graph = consolidate_placeholders(new_graph)
                if graph_fails(new_graph, new_inps):
                    # The replaced window is range(start_range, end_range), i.e. half-open.
                    print(
                        f"SUCCESS: Removed [{start_range}:{end_range}) - Went from {starting_placeholders} "
                        f"placeholders to {len(_get_placeholders(new_graph))}"
                    )
                    return (new_graph, new_inps), True
            gap //= 2

        print("FAIL: Could not remove prefix")
        # Return this strategy's own inputs, not the enclosing minimizer's `inps`.
        return (cur_graph, cur_inps), False

    print("###################")
    print(f"Current size: {len(failing_graph.nodes)}")
    print("###################")
    while True:
        any_succeeded = False
        strategies = [
            remove_suffix, eliminate_dead_code, remove_unused_inputs,
            delta_debugging, eliminate_dead_code, remove_unused_inputs
        ]
        for strategy in strategies:
            # Strategies may mutate their arguments, so hand each one a private copy.
            (cur_graph, cur_inps), succeeded = strategy(copy.deepcopy(failing_graph), inps[:])
            if succeeded:
                print()
                print("###################")
                print(f"Current size: {len(cur_graph.nodes)}")
                print("###################")
                failing_graph = cur_graph
                inps = cur_inps
                any_succeeded = True

        # Fixed point: a whole round made no progress.
        if not any_succeeded:
            break
    failing_fx = fx.GraphModule(fail_f, failing_graph)
    print(failing_fx.code)
    print([(i.shape, i.dtype) for i in inps])
    return failing_fx, inps