in python/graph_def_util.py [0:0]
# Imports and the attr-key constant assumed by this excerpt.
from tensorflow.python.framework import dtypes
from tensorflow.python.framework.tensor_shape import TensorShape

kOutputShapes = '_output_shapes'  # attr key TensorFlow uses for inferred output shapes


def normalize_operators(graph_def):
  """Rewrites training-oriented ops in `graph_def` in place to their
  inference equivalents and returns the same GraphDef."""
  gd_tensor_name_to_consumers = {}
  gd_tensor_name_to_shape = {}
  for node in graph_def.node:
    # Map each input tensor name to the names of the nodes consuming it.
    for inp in node.input:
      if inp not in gd_tensor_name_to_consumers:
        gd_tensor_name_to_consumers[inp] = []
      gd_tensor_name_to_consumers[inp].append(node.name)
    # Map each output tensor name to its inferred shape; output 0 goes by
    # the bare node name, output i by 'name:i'.
    if kOutputShapes in node.attr:
      for idx, shape in enumerate(node.attr[kOutputShapes].list.shape):
        tensor_name = node.name if idx == 0 else '{}:{}'.format(node.name, idx)
        gd_tensor_name_to_shape[tensor_name] = shape
  for node in graph_def.node:
    if node.op == 'StopGradient':
      # StopGradient is a no-op at inference time.
      node.op = 'Identity'
    elif node.op == 'FusedBatchNormV3':
      # FusedBatchNormV3 can be replaced by FusedBatchNorm for inference,
      # but only for float32 inputs and only when none of the reserve-space
      # outputs (indices 3-5), which feed the training gradient, has a
      # consumer.
      if node.attr['T'].type != dtypes.float32.as_datatype_enum:
        continue
      found_training_consumer = False
      for idx in range(3, 6):
        gd_tensor_name = '{}:{}'.format(node.name, idx)
        if gd_tensor_name_to_consumers.get(gd_tensor_name):
          found_training_consumer = True
          break
      if not found_training_consumer:
        node.op = 'FusedBatchNorm'
        node.attr.pop('U')  # FusedBatchNorm has no 'U' (mean/variance dtype) attr
        if kOutputShapes in node.attr:
          # Drop the shape of the sixth output, which FusedBatchNorm lacks.
          node.attr[kOutputShapes].list.shape.pop()
    elif node.op == 'AddV2':
      node.op = 'Add'  # Add and AddV2 are equivalent on numeric inputs
    elif node.op == 'BatchMatMulV2':
      # Only downgrade to BatchMatMul when no broadcasting can occur:
      # both input shapes must be known, of equal rank, and share all
      # batch dimensions (every dim but the last two).
      input0, input1 = node.input[0], node.input[1]
      if input0 not in gd_tensor_name_to_shape:
        continue
      if input1 not in gd_tensor_name_to_shape:
        continue
      shape0 = TensorShape(gd_tensor_name_to_shape[input0])
      shape1 = TensorShape(gd_tensor_name_to_shape[input1])
      if (shape0.rank is not None and shape0.rank == shape1.rank
          and shape0[:-2] == shape1[:-2]):
        node.op = 'BatchMatMul'
  return graph_def
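
A minimal usage sketch, assuming a frozen GraphDef serialized to disk; the model path and loading boilerplate below are illustrative assumptions, only normalize_operators itself comes from this file:

import tensorflow as tf
from tensorflow.core.framework import graph_pb2

graph_def = graph_pb2.GraphDef()
with tf.io.gfile.GFile('/tmp/frozen_model.pb', 'rb') as f:  # hypothetical path
  graph_def.ParseFromString(f.read())

# The rewrite happens in place, so the returned proto is the same object;
# the return value is only a convenience for chaining.
normalized = normalize_operators(graph_def)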