def normalize_operators()

in python/graph_def_util.py [0:0]


def normalize_operators(graph_def):
    """Rewrite newer TensorFlow op versions in ``graph_def`` to older equivalents, in place.

    Rewrites performed on each node:

    * ``StopGradient`` -> ``Identity``
    * ``FusedBatchNormV3`` -> ``FusedBatchNorm``, only when its ``T`` attr is
      float32 and none of its outputs 3..5 is consumed by another node
    * ``AddV2`` -> ``Add``
    * ``BatchMatMulV2`` -> ``BatchMatMul``, only when both input shapes are
      recorded via the ``kOutputShapes`` attr, have the same known rank, and
      have equal batch (all-but-last-two) dimensions

    Args:
        graph_def: a TensorFlow GraphDef-style protobuf (nodes with ``name``,
            ``op``, ``input`` and ``attr``); its nodes are modified in place.

    Returns:
        The same ``graph_def`` object that was passed in.
    """
    consumers, shapes = _normalize_operators_index_tensors(graph_def)
    for node in graph_def.node:
        if node.op == 'StopGradient':
            node.op = 'Identity'
        elif node.op == 'FusedBatchNormV3':
            _normalize_operators_fused_batch_norm_v3(node, consumers)
        elif node.op == 'AddV2':
            node.op = 'Add'
        elif node.op == 'BatchMatMulV2':
            _normalize_operators_batch_mat_mul_v2(node, shapes)
    return graph_def


def _normalize_operators_index_tensors(graph_def):
    """Build tensor-name -> consuming nodes and tensor-name -> recorded shape maps.

    Tensor names follow the naming this module uses for node outputs: output 0
    of node ``n`` is ``'n'`` and output ``i > 0`` is ``'n:i'``. Consumer keys
    are the raw strings found in ``node.input`` (including any ``'^name'``
    entries). Shapes are recorded only for nodes carrying the ``kOutputShapes``
    attr.
    """
    consumers = {}
    shapes = {}
    for node in graph_def.node:
        for inp in node.input:
            # Record the consuming node itself (the original stored the input name).
            consumers.setdefault(inp, []).append(node)
        if kOutputShapes in node.attr:
            for idx, shape in enumerate(node.attr[kOutputShapes].list.shape):
                tensor_name = node.name if idx == 0 else '{}:{}'.format(node.name, idx)
                shapes[tensor_name] = shape
    return consumers, shapes


def _normalize_operators_fused_batch_norm_v3(node, consumers):
    """Downgrade a float32 FusedBatchNormV3 node to FusedBatchNorm for inference use.

    The node is left unchanged when any of its outputs 3..5 has a consumer.
    Those outputs are the reserve spaces used by gradient/training ops.
    """
    # Check membership first: indexing a protobuf map with a missing key inserts it.
    if 'T' not in node.attr or node.attr['T'].type != dtypes.float32.as_datatype_enum:
        return
    if any(consumers.get('{}:{}'.format(node.name, idx)) for idx in range(3, 6)):
        return
    node.op = 'FusedBatchNorm'
    # FusedBatchNorm has no 'U' attr. Guard the delete so a missing key cannot raise.
    if 'U' in node.attr:
        del node.attr['U']
    if kOutputShapes in node.attr:
        # FusedBatchNorm has 5 outputs versus 6 for V3 (no reserve_space_3),
        # so trim only the surplus shape entry.
        num_fused_batch_norm_outputs = 5
        output_shapes = node.attr[kOutputShapes].list.shape
        while len(output_shapes) > num_fused_batch_norm_outputs:
            output_shapes.pop()


def _normalize_operators_batch_mat_mul_v2(node, shapes):
    """Downgrade BatchMatMulV2 to BatchMatMul when the known shapes imply no broadcast."""
    input0, input1 = node.input[0], node.input[1]
    if input0 not in shapes or input1 not in shapes:
        return
    shape0 = TensorShape(shapes[input0])
    shape1 = TensorShape(shapes[input1])
    # NOTE(review): how unknown (None) batch dims compare under TensorShape
    # equality may differ between TensorFlow versions. If they compare equal, a
    # BatchMatMulV2 that broadcasts at runtime could be downgraded here. Confirm
    # this is acceptable for the graphs this runs on.
    if shape0.rank is not None and shape0.rank == shape1.rank and shape0[:-2] == shape1[:-2]:
        node.op = 'BatchMatMul'