in python/graph_def_util.py [0:0]
def _get_fully_defined_operand_shape(name_to_node, shape_op_node):
    """Return the fully-defined TensorShape consumed by a Shape/Size node, or None.

    The shape is read from the operand node's ``kNeuronInferredShapes`` attribute,
    falling back to ``knOutputShapes``. None is returned when neither attribute
    exists, when the output port has no recorded shape, or when the recorded shape
    is not fully defined (such a shape cannot be turned into a constant).
    """
    # Control inputs (``^name``) may legitimately be attached to a Shape/Size node;
    # only the single data input carries the tensor whose shape is taken.
    data_inputs = [name for name in shape_op_node.input if not name.startswith('^')]
    if len(data_inputs) != 1:
        return None
    operand_node_name, port = split_tensor_name(data_inputs[0])
    operand_node = name_to_node[operand_node_name]
    shape_attr = operand_node.attr.get(kNeuronInferredShapes, None)
    if shape_attr is None:
        shape_attr = operand_node.attr.get(knOutputShapes, None)
    if shape_attr is None:
        return None
    recorded_shapes = shape_attr.list.shape
    if port >= len(recorded_shapes):
        return None
    shape = TensorShape(recorded_shapes[port])
    if not shape.is_fully_defined():
        # as_list() would contain None (or raise) and num_elements() would be None.
        return None
    return shape


def _convert_placeholder_to_const(ph_node, dtype_enum, content):
    """Rewrite subgraph node ``ph_node`` in place into a ``Const`` holding ``content``."""
    value_tensor = convert_to_tensor(content, dtypes.as_dtype(dtype_enum))
    ph_node.attr['dtype'].type = dtype_enum
    # A Const node takes no 'shape' attr; tolerate placeholders that never had one.
    if 'shape' in ph_node.attr:
        del ph_node.attr['shape']
    tensor_proto = ph_node.attr['value'].tensor
    tensor_proto.dtype = dtype_enum
    tensor_proto.tensor_shape.CopyFrom(value_tensor.shape.as_proto())
    tensor_proto.tensor_content = value_tensor.numpy().tobytes()
    ph_node.op = 'Const'


def _discard_neuron_inputs(node, discards):
    """Remove the inputs at positions ``discards`` from ``node`` and its per-input attrs.

    Per-input attributes that are absent from ``node.attr`` are left absent rather
    than being created as empty entries.
    """
    attr = node.attr

    def kept_values(container):
        return [value for idx, value in enumerate(container) if idx not in discards]

    node.input[:] = kept_values(node.input)
    scalar_fields = (
        (knInputNames, 's'),
        (knInputDtypes, 'type'),
        (knInputBatchAxis, 'i'),
        (knInputCanUseShm, 'b'),
    )
    for key, field in scalar_fields:
        if key in attr:
            container = getattr(attr[key].list, field)
            # Leave empty lists untouched so an unset attr value stays unset.
            if container:
                container[:] = kept_values(container)
    composite_fields = (
        (knInputShapes, 'shape'),
        (knInputShuffles, 'tensor'),
    )
    for key, field in composite_fields:
        if key in attr:
            container = getattr(attr[key].list, field)
            if container:
                kept = kept_values(container)
                del container[:]
                # extend() copies each message back into the repeated field.
                container.extend(kept)


def inline_shape_inputs_in_subgraphs(graph_def):
    """Inline shape-derived NeuronOp inputs as constants inside their subgraphs.

    For every NeuronOp input that is produced directly by a ``Shape`` or ``Size``
    operator, the operand's recorded shape is looked up (``kNeuronInferredShapes``
    first, then ``knOutputShapes``). If that shape is fully defined, the matching
    subgraph placeholder is rewritten into a ``Const`` with the computed value. The
    input is then removed from the NeuronOp's input signature (``node.input`` and
    every per-input attribute list). Inputs whose shape is unknown or only
    partially known are left untouched.

    Note that the current approach only deals with inputs that come from
    shape-related operators directly. In theory, a safer approach is to copy
    the entire graph, turn all inferrable shapes into constants, and then
    propagate through constant folding. That is not practical at this point,
    because copying the entire graph would consume too much memory. It will
    become practical once we no longer have to freeze the entire graph.

    Args:
        graph_def: A GraphDef whose NeuronOp nodes are modified in place.

    Returns:
        The same ``graph_def`` object.
    """
    name_to_node = {node.name: node for node in graph_def.node}
    shape_content_fn_map = {
        'Shape': TensorShape.as_list,
        'Size': TensorShape.num_elements,
    }

    def get_node(name):
        node_name, _ = split_tensor_name(name)
        return name_to_node[node_name]

    def contains_shape_input(node):
        # True if any non-control input is produced by a shape-related operator.
        return any(get_node(name).op in shape_content_fn_map
                   for name in node.input if not name.startswith('^'))

    # Only NeuronOps with shape inputs need their subgraph deserialized.
    candidates = [node for node in get_neuron_nodes(graph_def) if contains_shape_input(node)]
    if not candidates:
        return graph_def
    for node in candidates:
        if knInputNames not in node.attr:
            continue
        subgraph_def = get_subgraph_def(node)
        subgraph_name_to_node = {sn.name: sn for sn in subgraph_def.node}
        discards = set()
        input_names = node.attr[knInputNames].list.s
        for idx, (input_name, ph_name) in enumerate(zip(node.input, input_names)):
            input_node = get_node(input_name)
            if input_node.op not in shape_content_fn_map:
                continue
            shape = _get_fully_defined_operand_shape(name_to_node, input_node)
            if shape is None:
                continue
            dtype_enum = input_node.attr['out_type'].type
            content = shape_content_fn_map[input_node.op](shape)
            ph_node_name, _ = split_tensor_name(ph_name.decode())
            _convert_placeholder_to_const(subgraph_name_to_node[ph_node_name], dtype_enum, content)
            discards.add(idx)
        if not discards:
            continue
        _discard_neuron_inputs(node, discards)
        node.attr[knGraphDef].s = subgraph_def.SerializeToString()
    return graph_def