def _get_relay_func()

in python/tvm/relay/frontend/tensorflow.py [0:0]


    def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
        """Construct a Relay function from a tensorflow graph definition - GraphDef.

        Follow the tensorflow graph definition to parse and convert it to Relay.
        Some of the assumptions listed below.

            -> All Placeholders are considered as graph input.
            -> All Const nodes are params.
            -> Last node is assumed as graph output.
            -> _output_shapes : Graph should be frozen with add_shapes=True.
                                Or user can pass input shape dictionary optionally.
            -> DecodeJpeg, ResizeBilinear: These are dummy operators.
                                           Hence user should handle preprocessing outside.
            -> CheckNumerics: No implementation as of now for this.
                              Just copies input to output.

        Besides returning the function, this method records ``shape``, ``layout``
        and ``graph`` on the instance, fills the per-node bookkeeping maps
        (shapes, nodes, control flow and tensor array information) and finally
        replaces ``self._params`` with only those params that are free variables
        of the returned function.

        Parameters
        ----------
        graph : tensorflow graph definition object
            The loaded tensorflow GraphDef

        layout : target layout to be used (Optional)
            NCHW only supported now to enable NHWC models on GPU.

        shape : Dictionary of input dimensions (Optional)
            Graph level input shape dictionary.

        outputs : List of output tensor names (Optional)
            if not specified then the last node is assumed as graph output.
            A name may carry an output index as ``"node_name:index"``.

        Returns
        -------
        func : relay Function
            The function built from the requested outputs, whose parameters are
            the free variables of the output expression.

        Raises
        ------
        ImportError
            If tensorflow cannot be imported.
        Exception
            If the graph contains operators that indicate it is not frozen.
        NotImplementedError
            If the graph contains operators that are not implemented.
        """
        from collections import deque

        # tensorflow is imported at call time; failure is re-raised with a
        # clearer message while keeping the original error as the cause.
        try:
            from tensorflow.python.framework import tensor_util
        except ImportError as e:
            raise ImportError(
                "Unable to import tensorflow which is required {}".format(e)
            ) from e

        missing_operators = self._parse_import_prerequisites(graph)
        control_flow_nodes = []
        ta_write_nodes = []
        ta_gather_nodes = []
        ta_construct_nodes = []
        self._in_shape = shape
        self._layout = layout
        self._graph = graph

        if missing_operators:
            freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list]
            if freezed_ops:
                raise Exception(
                    "Graph is not frozen. Provide a frozen graph. "
                    "Found operators {}".format(freezed_ops)
                )

            raise NotImplementedError(
                "The following operators are not implemented: {}".format(missing_operators)
            )

        for node in graph.node:
            # Everything before the last "/" is treated as the enclosing scope name.
            node_name_prefix = node.name.rsplit("/", 1)[0]
            self._control_flow_node_map[node_name_prefix].add(node.op)
            self._tf_node_map[node.name] = node

            # Parse output_shapes attribute
            parsed_attr = self._parse_attr(node.attr)
            if "_output_shapes" in parsed_attr:
                self._output_shapes[node.name] = [
                    tensor_util.TensorShapeProtoToList(tshape)
                    for tshape in parsed_attr["_output_shapes"]
                ]
            else:
                self._output_shapes[node.name] = [None]

            # Parse placeholder and const here since input shape info is required.
            if node.op == "Placeholder" or node.op == "PlaceholderWithDefault":
                # Give priority to user argument.
                if shape and node.name in shape:
                    self._input_shapes[node.name] = list(shape[node.name])
                else:
                    self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(
                        node.attr["shape"].shape
                    )
                    # Negative dimensions are replaced by a dynamic Any() dimension.
                    for idx, dim in enumerate(self._input_shapes[node.name]):
                        if dim < 0:
                            self._input_shapes[node.name][idx] = Any()

                self._output_shapes[node.name] = [self._input_shapes[node.name]]
                attr = self._parse_attr(node.attr)
                self._nodes[node.name] = [
                    _expr.var(
                        node.name, shape=self._input_shapes[node.name], dtype=attr["dtype"].name
                    )
                ]

            elif node.op == "Const":
                # Ignore user's input shape for Non placeholder
                tensor_value = node.attr["value"].tensor
                self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(
                    tensor_value.tensor_shape
                )
                self._output_shapes[node.name] = [self._input_shapes[node.name]]
                if shape and node.name in shape:
                    warnings.warn(
                        "Ignore the passed shape. Shape in graphdef "
                        "will be used for operator %s." % node.name
                    )
                for key, value in node.attr.items():
                    self._parse_param(key, value, node.name, self._in_shape)
            elif node.op in _control_flow_nodes:
                # We assume that the direct parent node of Exit is a while loop block
                if node.op == "Exit":
                    self._while_loop_name_set.add(node_name_prefix)
                control_flow_nodes.append(node)
            elif node.op.startswith("TensorArray"):
                if is_tensor_array_constuctor(node):
                    ta_construct_nodes.append(node)
                else:
                    # Record at most one write entry per node (first matching prefix).
                    for ta_write_name, idx in _tensor_array_write_ops.items():
                        if node.op.startswith(ta_write_name):
                            ta_write_nodes.append((node, idx))
                            break
                    if node.op.startswith("TensorArrayGather"):
                        ta_gather_nodes.append(node)

        # Use tensor array gather to infer static tensor array shape
        for gather_node in ta_gather_nodes:
            input_ta_name = gather_node.input[0]
            input_ta_node = self._tf_node_map[input_ta_name]
            if is_tensor_array_constuctor(input_ta_node):
                gather_attr = self._parse_attr(gather_node.attr)
                if "element_shape" not in gather_attr:
                    continue
                raw_elem_shape = tensor_util.TensorShapeProtoToList(gather_attr["element_shape"])
                elem_shape = [Any() if dim < 0 else int(dim) for dim in raw_elem_shape]
                self._tensor_array_shapes[input_ta_node.name] = elem_shape

        # Fetch node contains static tensor array shape
        for wnode, (ta_idx, inode_idx) in ta_write_nodes:
            # Breadth-first walk from the write node's tensor array input until
            # the first TensorArray* node other than the write node itself.
            pending = deque([self._tf_node_map[wnode.input[ta_idx].split(":")[0]]])
            while pending:
                cnode = pending.popleft()
                if not cnode.op.startswith("TensorArray"):
                    for iname in cnode.input:
                        pending.append(self._tf_node_map[iname.split(":")[0]])
                elif cnode.name != wnode.name:
                    if is_tensor_array_constuctor(cnode):
                        # "name[:index]" -> producing node and its output index.
                        tn = wnode.input[inode_idx].split(":")
                        inode = self._tf_node_map[tn[0]]
                        output_index = int(tn[1]) if len(tn) > 1 else 0
                        self._tensor_array_shape_nodes[cnode.name] = (inode, wnode.op, output_index)
                    break

        # First, parse all control flow nodes.
        # Convert tf.cond to Branch and tf.while_loop to Loop.
        sorted_cf_nodes = []
        exit_pos_map = {}
        ordered_prefix = []
        # Sort control flow nodes to move all Exit nodes to the end
        # of corresponding while_loop block.
        for node in control_flow_nodes:
            loop_name = find_parent_loop_name(node.name, self._while_loop_name_set)
            if node.op == "Exit":
                if loop_name not in exit_pos_map:
                    ordered_prefix.append(loop_name)
                    exit_pos_map[loop_name] = len(sorted_cf_nodes)
                sorted_cf_nodes.append(node)
            elif loop_name in self._while_loop_name_set:
                if loop_name not in exit_pos_map:
                    sorted_cf_nodes.append(node)
                else:
                    # Insert before this loop's first Exit, then shift the recorded
                    # Exit positions of this loop and of loops registered after it.
                    sorted_cf_nodes.insert(exit_pos_map[loop_name], node)
                    for j in range(ordered_prefix.index(loop_name), len(ordered_prefix)):
                        exit_pos_map[ordered_prefix[j]] += 1
            else:
                sorted_cf_nodes.append(node)

        # All names are recorded before any control flow node is constructed.
        self._sorted_cf_node_names.extend(node.name for node in sorted_cf_nodes)

        for node in sorted_cf_nodes:
            self._backtrack_construct(node.name)

        # Second, parse other nodes to re-create TF graph using Relay operators.
        for node in graph.node:
            self._backtrack_construct(node.name)

        out = []
        if outputs is None:
            last_node = graph.node[-1]
            op = self._nodes[last_node.name.split(":")[0]]
            if last_node.op == "Exit":
                out = [op[0].tuple_value]
            else:
                out = op
        else:
            for out_name in outputs:
                if ":" in out_name:
                    out_name, out_num = out_name.split(":")
                    out_num = int(out_num)
                    out.append(self._nodes[out_name][out_num])
                else:
                    out.append(self._nodes[out_name][0])

        if isinstance(out, _expr.TupleWrapper):
            out = out.tuple_value
        else:
            out = out[0] if len(out) == 1 else _expr.Tuple(out)
        fvars = analysis.free_vars(out)
        func = _function.Function(fvars, out)
        # Keep only the params that the final function actually references.
        self._params = {
            fv.name_hint: self._params[fv.name_hint]
            for fv in fvars
            if fv.name_hint in self._params
        }
        return func