in scripts/parser.py [0:0]
# TensorProto fields that can carry tensor payload; they are cleared together
# so a stripped tensor does not keep a stale copy of its data in another field.
_TENSOR_DATA_FIELDS = (
    "raw_data",
    "float_data",
    "int32_data",
    "int64_data",
    "double_data",
    "uint64_data",
    "string_data",
)


def strip_data(graph, size_limit=1 * 1024 * 1024):
    """Strip weight data from an ONNX graph in place.

    All initializers of ``graph`` are removed. For every ``Constant`` node,
    the tensor stored under the ``value`` attribute has its data fields
    cleared when its ``raw_data`` is larger than ``size_limit`` bytes.
    Subgraphs held in node attributes (``GRAPH`` / ``GRAPHS``) are processed
    recursively with the same ``size_limit``.

    Args:
        graph: Graph message to modify; it must expose ``initializer`` and
            ``node`` (presumably an ``onnx.GraphProto`` -- confirm with
            callers).
        size_limit: Threshold in bytes; a Constant tensor is stripped only
            when ``len(raw_data)`` is strictly greater than this value.
            Defaults to 1 MiB.

    Returns:
        None. ``graph`` is mutated in place.
    """
    # Remove initializers in the current graph.
    del graph.initializer[:]

    for node in graph.node:
        is_constant = node.op_type == "Constant"
        for attr in node.attribute:
            # Recurse into subgraphs, forwarding size_limit so a custom
            # threshold is honoured at every nesting level rather than
            # falling back to the default.
            if attr.type == onnx.AttributeProto.GRAPH:
                strip_data(attr.g, size_limit)
            elif attr.type == onnx.AttributeProto.GRAPHS:
                for subgraph in attr.graphs:
                    strip_data(subgraph, size_limit)

            # ONNX Constant nodes store their tensor under attribute 'value'.
            # Reading a protobuf sub-message never yields None, so presence
            # has to be tested with HasField().
            if is_constant and attr.name == "value" and attr.HasField("t"):
                tensor = attr.t
                # NOTE(review): only raw_data is measured, so a large tensor
                # stored in the typed fields (float_data, ...) is kept as-is.
                # Confirm whether that is intended.
                if len(tensor.raw_data) > size_limit:
                    # Remove all data fields from the TensorProto.
                    for field_name in _TENSOR_DATA_FIELDS:
                        tensor.ClearField(field_name)