def inline_shape_inputs_in_subgraphs()

in python/graph_def_util.py [0:0]


def inline_shape_inputs_in_subgraphs(graph_def):
    """
    If NeuronOp has inputs that come from shape-related operators, then inline
    them as constants in the subgraph and remove them from the input signature.

    Note that the current approach only deals with inputs that come from
    shape-related operators directly. In theory, a safer approach is to copy
    the entire graph, turn all inferrable shapes into constants, and then
    propagate through constant folding. It is not practical at this point as
    copying the entire graph would consume too much memory and it will become
    practical once we don't have to freeze the entire graph.
    """
    # Map from node name to NodeDef for fast lookup across the whole graph.
    name_to_node = {node.name: node for node in graph_def.node}
    # For each supported shape-related op, the TensorShape method that
    # computes the op's output from a statically known input shape.
    shape_content_fn_map = {'Shape': TensorShape.as_list, 'Size': TensorShape.num_elements}

    def get_node(name):
        # Resolve a tensor name such as 'node:0' to its producing NodeDef.
        node_name, _ = split_tensor_name(name)
        return name_to_node[node_name]

    def contains_shape_input(node):
        # Returns True if any non-control input comes from a shape-related operator
        return any(
            get_node(name).op in shape_content_fn_map
            for name in node.input if not name.startswith('^')
        )

    # Fast path: leave the graph untouched if no NeuronOp consumes a shape input.
    if not any(contains_shape_input(node) for node in get_neuron_nodes(graph_def)):
        return graph_def
    for node in get_neuron_nodes(graph_def):
        subgraph_def = get_subgraph_def(node)
        subgraph_name_to_node = {sn.name: sn for sn in subgraph_def.node}
        attr = node.attr
        discards = set()  # indices of NeuronOp inputs to drop once inlined
        # node.input and the knInputNames attribute are index-aligned: entry
        # idx of each refers to the same NeuronOp input.
        for idx, (input_name, ph_name) in enumerate(zip(node.input, attr[knInputNames].list.s)):
            input_node = get_node(input_name)
            if input_node.op in shape_content_fn_map:
                # 'Shape' and 'Size' take exactly one input: the tensor whose
                # shape they compute.
                shape_input_name, = input_node.input
                shape_input_node_name, port = split_tensor_name(shape_input_name)
                shape_input_node = name_to_node[shape_input_node_name]
                # Prefer shapes inferred by Neuron; fall back to the standard
                # output-shapes attribute.
                shape_attr = shape_input_node.attr.get(kNeuronInferredShapes, None)
                if shape_attr is None:
                    shape_attr = shape_input_node.attr.get(knOutputShapes, None)
                if shape_attr is None:
                    # No statically known shape; leave this input untouched.
                    continue
                shape_proto = shape_attr.list.shape[port]
                shape = TensorShape(shape_proto)
                dtype_enum = input_node.attr['out_type'].type
                dtype = dtypes.as_dtype(dtype_enum)
                # Evaluate the shape operator at rewrite time, e.g., [2, 3]
                # for 'Shape' or 6 for 'Size'.
                tensor_content = shape_content_fn_map[input_node.op](shape)
                shape_tensor = convert_to_tensor(tensor_content, dtype)
                # Rewrite the matching subgraph placeholder into a Const that
                # holds the precomputed value.
                ph_node_name, _ = split_tensor_name(ph_name.decode())
                ph_node = subgraph_name_to_node[ph_node_name]
                ph_node.attr['dtype'].type = dtype_enum
                ph_node.attr.pop('shape')
                tensor_proto = ph_node.attr['value'].tensor
                tensor_proto.dtype = dtype_enum
                tensor_proto.tensor_shape.CopyFrom(shape_tensor.shape.as_proto())
                tensor_proto.tensor_content = shape_tensor.numpy().tobytes()
                ph_node.op = 'Const'
                discards.add(idx)
        if not discards:
            continue

        def maybe_discard_from_scalar_container(container):
            # Drop the entries at discarded indices from a repeated scalar field.
            if container:
                container[:] = [value for idx, value in enumerate(container) if idx not in discards]

        def maybe_discard_from_composite_container(container):
            # Repeated message fields do not support slice assignment, so pop
            # every entry and re-add the ones that survive.
            if container:
                new_values = [value for idx, value in enumerate(container) if idx not in discards]
                while container:
                    container.pop()
                for value in new_values:
                    container.add().CopyFrom(value)

        # Per-input attributes that must stay index-aligned with node.input.
        scalar_containers = [
            node.input,
            attr[knInputNames].list.s,
            attr[knInputDtypes].list.type,
            attr[knInputBatchAxis].list.i,
        ]
        for container in scalar_containers:
            maybe_discard_from_scalar_container(container)
        maybe_discard_from_composite_container(attr[knInputShapes].list.shape)
        if knInputShuffles in attr:
            maybe_discard_from_composite_container(attr[knInputShuffles].list.tensor)
        if knInputCanUseShm in attr:
            maybe_discard_from_scalar_container(attr[knInputCanUseShm].list.b)
        # Serialize the modified subgraph back into the NeuronOp.
        node.attr[knGraphDef].s = subgraph_def.SerializeToString()
    return graph_def
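
For illustration, the following is a minimal, self-contained sketch of the core rewrite performed inside the loop above: converting a Placeholder NodeDef into a Const that carries a precomputed shape value. The toy graph and the node name 'x_shape' are invented for the example; only standard TensorFlow APIs are used.

import tensorflow as tf
from tensorflow.core.framework import graph_pb2

# Toy subgraph: 'x_shape' stands in for a placeholder fed by a 'Shape'
# operator whose result is statically known to be [2, 3].
graph_def = graph_pb2.GraphDef()
ph = graph_def.node.add()
ph.name = 'x_shape'
ph.op = 'Placeholder'
ph.attr['dtype'].type = tf.int32.as_datatype_enum
ph.attr['shape'].shape.CopyFrom(tf.TensorShape([2]).as_proto())

# The precomputed value that the 'Shape' operator would produce.
shape_tensor = tf.constant([2, 3], dtype=tf.int32)

# The same Placeholder-to-Const rewrite as in inline_shape_inputs_in_subgraphs.
ph.attr.pop('shape')
tensor_proto = ph.attr['value'].tensor
tensor_proto.dtype = tf.int32.as_datatype_enum
tensor_proto.tensor_shape.CopyFrom(shape_tensor.shape.as_proto())
tensor_proto.tensor_content = shape_tensor.numpy().tobytes()
ph.op = 'Const'

assert graph_def.node[0].op == 'Const'
print(tf.make_ndarray(tensor_proto))  # -> [2 3]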