def restore_compiler_failures()

in python/graph_def_util.py [0:0]


def restore_compiler_failures(compiled_graph_def, original_graph_def):
    """
    Restore `NeuronOp`'s that failed to compile.

    Every `NeuronOp` in `compiled_graph_def` whose executable attribute is
    empty is dropped and replaced by the non-placeholder nodes of its embedded
    subgraph. Subgraph placeholders are rewired to the tensors that fed the
    `NeuronOp`. Consumers of the `NeuronOp`'s outputs are rewired to the
    corresponding subgraph output tensors. A restored node also replaces any
    compiled-graph node that has the same name. Control inputs that a restored
    node had in `original_graph_def` are re-attached when their source node
    still exists in the result.

    Args:
        compiled_graph_def: GraphDef produced by the compilation passes. Note
            that the input lists of its nodes are rewritten in place.
        original_graph_def: GraphDef before compilation. It is only consulted
            to look up control inputs of restored nodes.

    Returns:
        A new GraphDef with failed `NeuronOp`'s replaced by their subgraph
        nodes, and with `compiled_graph_def.library` copied over.

    TODO: Some passes introduced recently can change subgraph input/output
    signatures. To deal with these cases properly, we need to obtain original
    NodeDef messages from `original_graph_def` instead of `subgraph_def`.
    """
    neuron_op_dict = {node.name: node for node in get_neuron_nodes(compiled_graph_def)}
    restore_nodes = []
    remove_node_names = set()
    # Maps an output tensor name of a failed NeuronOp to the subgraph tensor
    # that takes its place.
    gd_tensor_name_map = {}
    # Names of nodes that will be present in the returned graph. This is used
    # to decide which original control inputs can be re-attached.
    all_expected_node_names = {node.name for node in compiled_graph_def.node if node.op != tNeuronOp}
    for node in get_neuron_nodes(compiled_graph_def):
        if not node.attr[knExecutable].s:
            remove_node_names.add(node.name)
            subgraph_def = get_subgraph_def(node)
            # Maps a subgraph placeholder name to the outer-graph tensor that feeds it.
            sgd_tensor_name_map = {}
            for gd_ts_name, sg_ph_name in zip(node.input, node.attr[knInputNames].list.s):
                sgd_ph_name = format_tensor_name(sg_ph_name.decode())
                op_name, ts_index = _graph_def_op_index(gd_ts_name)
                if op_name in neuron_op_dict:
                    in_node = neuron_op_dict[op_name]
                    if not in_node.attr[knExecutable].s:
                        # The producer NeuronOp is being restored as well, so
                        # feed from the matching tensor inside its subgraph.
                        gd_ts_name = in_node.attr[knOutputNames].list.s[ts_index].decode()
                sgd_tensor_name_map[sgd_ph_name] = gd_ts_name
            for sg_node in subgraph_def.node:
                for idx, name in enumerate(sg_node.input):
                    sg_node.input[idx] = sgd_tensor_name_map.get(name, name)
                # Placeholders are not restored. Their consumers were rewired
                # to the outer-graph tensors just above.
                if sg_node.op != tPlaceholder:
                    restore_nodes.append(sg_node)
                    all_expected_node_names.add(sg_node.name)
            for out_idx, out_name in enumerate(node.attr[knOutputNames].list.s):
                out_gd_ts_name = format_tensor_name('{}:{}'.format(node.name, out_idx))
                gd_tensor_name_map[out_gd_ts_name] = format_tensor_name(out_name.decode())
    # A restored node wins over a compiled-graph node with the same name.
    restore_node_names = {node.name for node in restore_nodes}
    remove_node_names.update(
        node.name for node in compiled_graph_def.node if node.name in restore_node_names)
    original_node_with_control_inputs = get_node_with_control_inputs(original_graph_def)
    for node in restore_nodes:
        if node.name in original_node_with_control_inputs:
            input_names = original_node_with_control_inputs[node.name]
            for name in input_names:
                # Entries presumably use the GraphDef control-input form
                # '^op_name' (TODO confirm against get_node_with_control_inputs).
                # Strip the '^' marker and any ':index' suffix before comparing
                # with plain node names. Otherwise a '^op_name' entry can never
                # match and the control dependency would be silently dropped.
                src_op_name = name.lstrip('^').split(':')[0]
                if src_op_name in all_expected_node_names and name not in node.input:
                    node.input.append(name)
    # Rewire consumers of failed NeuronOp outputs. This mutates
    # compiled_graph_def in place.
    for node in compiled_graph_def.node:
        for idx, name in enumerate(node.input):
            node.input[idx] = gd_tensor_name_map.get(name, name)

    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend(
        node for node in compiled_graph_def.node if node.name not in remove_node_names)
    graph_def.node.extend(node for node in restore_nodes)

    # Remove inputs that reference nodes absent from the new graph.
    node_names = {node.name for node in graph_def.node}
    for node in graph_def.node:
        node.input[:] = [name for name in node.input if _graph_def_op_index(name)[0] in node_names]

    # preserve information for function-call operators (e. g., MapDataset)
    graph_def.library.CopyFrom(compiled_graph_def.library)
    return graph_def