def inference_graph_from_session()

in python/graph_util.py [0:0]


def inference_graph_from_session(
        sess=None, input_tensors=None, output_tensors=None, signature_def=None,
        shape_feed_dict=None, feed_dict=None, dynamic_batch_size=False,
        protected_op_names=None,
        supported_op_types=None, no_fuse_ops=None, force_fuse_ops=None, minimum_segment_size=None,
        grappler=False, max_num_compilers=None,
        compiler_args=None, compiler_workdir=None, compiler_timeout=None, compiler_recovery=True,
        compiler_verbose=None, amp_pass=False):
    """Constructs an inference graph from a tensorflow session.

    Generally decomposes into 5 passes:
        1. Convert all variables to constants, `Assign`s to `Identity`s.
        2. Whitelist-based graph partitioning, each subgraph (wrapped in an `NeuronOp`)
            will contain only operations whose types match the types listed in `supported_op_types`.
        3. Shape inference to find shapes for input/output tensors of `NeuronOp` subgraphs.
        4. Call neuron-cc compiler on each `NeuronOp`.
        5. Restore `NeuronOp`s that failed to compile into their original form.

    If the `NEURON_CC_FLAGS` environment variable is set, it is parsed as follows:
    `--verbose N` overrides `compiler_verbose`; `--must-compile` disables
    `compiler_recovery` (and sets `compiler_verbose` to 1 if it is still unset);
    `--dump-prefix DIR` overrides `compiler_workdir`; all remaining flags are appended
    to `compiler_args` (the caller's list is not modified).

    Args:
        sess: Active TensorFlow session. Defaults to the current default session.
        input_tensors: None or iterable of strings/tensors (unordered). Strings should be
            tensor names. Setting this argument can help when inference starts from some
            arbitrary tensors that are not placeholder tensors.
        output_tensors: None or iterable of strings/tensors (unordered). Strings should be
            tensor names.
        signature_def: None or a `SignatureDef` protobuf message marking graph inputs and outputs.
        shape_feed_dict: Dict `{str: shape}` used by `shape_inference`.
        feed_dict: Dict `{str: numpy.ndarray}` used by `shape_inference_with_inputs`.
            Optional. If both `shape_feed_dict` and `feed_dict` are unspecified, no shape
            inference will be performed. If only `shape_feed_dict` is specified, will perform
            `shape_inference` only. As long as `feed_dict` is specified, will perform
            `shape_inference` first and then `shape_inference_with_inputs`.
        dynamic_batch_size: Bool that represents whether the inference graph will support
            dynamic batch sizes during inference.
        protected_op_names: None or iterable of op names. These names, together with the
            ops producing the signature inputs/outputs and the fed tensors, are passed as the
            output node names of the variables-to-constants conversion.
        supported_op_types: Iterable of strings (unordered) representing compilable op names.
        no_fuse_ops: None or iterable of strings (unordered) representing names of ops
            that are forcibly placed on CPU.
        force_fuse_ops: None or iterable of strings (unordered) representing names of ops
            that are forcibly sent to the neuron-cc compiler.
        minimum_segment_size: Integer; minimum number of ops in an `NeuronOp` used by
            `whitelist_partition`.
        grappler: Bool; if True, run TensorFlow grappler (`tf_optimizer.OptimizeGraph`) with
            a default `RewriterConfig` on the session graph before variable conversion,
            using the signature outputs as fetch nodes.
        max_num_compilers: Integer representing maximum allowed compiler processes.
        compiler_args: List of strings representing compiler arguments. Note that these
            arguments will be applied to all subgraphs generated by whitelist partitioning.
        compiler_workdir: Str representing work directory of the neuron-cc compiler.
        compiler_timeout: Integer representing maximum allowed runtime for the neuron-cc compiler.
        compiler_recovery: Bool representing whether to recover from neuron-cc compiler failure.
        compiler_verbose: Verbosity level forwarded to `compile_subgraphs`.
        amp_pass: Bool; if True, run `amp_optimization` on the graph before partitioning.

    Returns:
        A `Graph` object that is optimized for running inference on Inferentia.

    Raises:
        ValueError: if any `NeuronOp` is still uncompiled after the recovery pass.

    Note:
        `input_tensors`, `shape_feed_dict`, and `feed_dict` can all set input tensors, and so
        the latter one will always override the former one.
    """
    if 'NEURON_CC_FLAGS' in os.environ:
        parser = argparse.ArgumentParser()
        parser.add_argument('--must-compile', action='store_true')
        parser.add_argument('--dump-prefix', default=None)
        parser.add_argument('--verbose', type=int, default=None)
        tf_neuron_args, neuron_cc_args = parser.parse_known_args(shlex.split(os.environ['NEURON_CC_FLAGS']))
        if tf_neuron_args.verbose is not None:
            compiler_verbose = tf_neuron_args.verbose
        if tf_neuron_args.must_compile:
            compiler_recovery = False
            if compiler_verbose is None:
                compiler_verbose = 1
            logging.warning('Enabling must-compile according to NEURON_CC_FLAGS environment variable; '
                            'neuron-cc failures will be thrown as exceptions')
        if tf_neuron_args.dump_prefix is not None:
            compiler_workdir = tf_neuron_args.dump_prefix
        if neuron_cc_args:
            # Build a fresh list instead of extending in place: `compiler_args` belongs to
            # the caller, and an in-place extend would accumulate flags across calls.
            compiler_args = list(compiler_args or []) + neuron_cc_args
    if sess is None:
        sess = ops.get_default_session()
    # normalize tensor keys to tensor-name strings
    if feed_dict is not None:
        feed_dict = {getattr(ts, 'name', ts): value for ts, value in feed_dict.items()}
    if shape_feed_dict is not None:
        shape_feed_dict = {getattr(ts, 'name', ts): value for ts, value in shape_feed_dict.items()}
    if signature_def is None:
        # build a SignatureDef from input/output tensors
        if input_tensors is None:
            if feed_dict is not None:
                input_names = feed_dict.keys()
            elif shape_feed_dict is not None:
                input_names = shape_feed_dict.keys()
            else:
                input_names = [op.outputs[0].name for op in sess.graph.get_operations()
                                                  if op.type == 'Placeholder']
        else:
            input_names = [getattr(ts, 'name', ts) for ts in input_tensors]
        input_tensors = [sess.graph.get_tensor_by_name(name) for name in input_names]
        if output_tensors is None:
            # default outputs: every output of ops whose outputs have no consumers
            output_ops = [op for op in sess.graph.get_operations()
                             if all(not ts.consumers() for ts in op.outputs)]
            output_names = [ts.name for op in output_ops for ts in op.outputs]
        else:
            output_names = [getattr(ts, 'name', ts) for ts in output_tensors]
        output_tensors = [sess.graph.get_tensor_by_name(name) for name in output_names]
        signature_def = mgu.build_signature_def(input_tensors, output_tensors)

    # collect op names that must survive variable-to-constant conversion
    if protected_op_names is None:
        protected_op_names = set()
    protected_op_names = set(protected_op_names)
    io_infos = itertools.chain(signature_def.inputs.values(), signature_def.outputs.values())
    protected_op_names.update(sess.graph.get_tensor_by_name(info.name).op.name for info in io_infos)
    if feed_dict is not None:
        protected_op_names.update(sess.graph.get_tensor_by_name(name).op.name
                                  for name in feed_dict.keys())
        if shape_feed_dict is None:
            # derive shapes by evaluating the fed tensors once
            key_dict = {key: key for key in feed_dict}
            evaluated_feed_dict = sess.run(key_dict, feed_dict)
            shape_feed_dict = {key: value.shape for key, value in evaluated_feed_dict.items()}
    if shape_feed_dict is not None:
        protected_op_names.update(sess.graph.get_tensor_by_name(name).op.name
                                  for name in shape_feed_dict.keys())

    if grappler:
        # Take fetch nodes from the SignatureDef so this works whether it was built
        # above or supplied by the caller (`output_names` only exists in the former case).
        fetch_names = [info.name for info in signature_def.outputs.values()]
        with sess.graph.as_default():
            rewriter_config = rewriter_config_pb2.RewriterConfig()
            opt_config = config_pb2.ConfigProto()
            opt_config.graph_options.rewrite_options.CopyFrom(rewriter_config)
            train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
            saved_train_op = list(train_op)
            train_op.extend(sess.graph.get_tensor_by_name(name) for name in fetch_names)
            try:
                grappler_metagraph = meta_graph.create_meta_graph_def(graph=sess.graph)
            finally:
                # Restore the caller's TRAIN_OP collection so the session graph is not
                # left modified (and repeated calls do not pile up duplicate entries).
                train_op[:] = saved_train_op
            graph_def = tf_optimizer.OptimizeGraph(opt_config, grappler_metagraph)
    else:
        graph_def = sess.graph.as_graph_def(add_shapes=True)

    # convert all variables to constants
    with replace_extract_sub_graph():
        graph_def = tf_graph_util.convert_variables_to_constants.__wrapped__(
            sess, graph_def, list(protected_op_names))
    for node in graph_def.node:
        if node.op == 'RefEnter':
            node.op = 'Enter'
        elif node.op == 'RefExit':
            node.op = 'Exit'
    original_graph_def = graph_def

    # setup op exclusions
    no_fuse_ops = set() if no_fuse_ops is None else set(no_fuse_ops)
    control_op_names = [node.name for node in graph_def.node if _has_control_input(node)]

    # exclude ops with control outputs
    no_fuse_ops.update(control_op_names)

    # exclude ops that are attached to string tensors
    name_to_consumers = {}
    for node in graph_def.node:
        for inp in node.input:
            # strip the ":<output index>" suffix to get the producing node name
            input_node_name = inp.split(':')[0]
            name_to_consumers.setdefault(input_node_name, set()).add(node.name)
    string_enum = dtypes.string.as_datatype_enum
    for node in graph_def.node:
        for attr_name in ('T', 'dtype'):
            if attr_name in node.attr and node.attr[attr_name].type == string_enum:
                no_fuse_ops.add(node.name)
                no_fuse_ops.update(name_to_consumers.get(node.name, []))

    # normalize operators
    graph_def = gdu.normalize_operators(graph_def)

    # initialize inferred shapes
    graph_def = gdu.encode_inferred_shapes(graph_def, shape_feed_dict)

    # Adding the auto-mixed precision pass
    if amp_pass:
        graph_def = amp_optimization(graph_def, signature_def)

    # fuse ops into `NeuronOp`'s and determine tensors that require shapes
    part_graph_def = whitelist_partition(
        graph_def, signature_def, supported_op_types=supported_op_types,
        no_fuse_ops=no_fuse_ops, force_fuse_ops=force_fuse_ops,
        minimum_segment_size=minimum_segment_size)

    # perform an inference to find tensor shapes as a last resort
    # todo: change to hard_shape_inference == True
    if feed_dict is not None:
        part_graph_def = gdu.shape_inference_with_inputs(part_graph_def, sess, feed_dict)

    # call compiler for each `NeuronOp`
    args_dict = {}
    if compiler_args is not None:
        args_dict = {node.name: compiler_args for node in gdu.get_neuron_nodes(part_graph_def)}
    compiled_graph_def = compile_subgraphs(
        part_graph_def, workdir=compiler_workdir,
        args_dict=args_dict, timeout=compiler_timeout, max_num_compilers=max_num_compilers,
        verbose=compiler_verbose)

    if dynamic_batch_size:
        compiled_graph_def = mark_batch_axis(compiled_graph_def)

    if compiler_recovery:
        compiled_graph_def = gdu.restore_compiler_failures(compiled_graph_def, original_graph_def)
        compiled_graph_def = gdu.run_graph_def_pass_in_subgraphs(compiled_graph_def, gdu.erase_large_constants)
        compiled_graph_def = nchw_to_nhwc(compiled_graph_def)

    # try to enable dynamic batch size if possible
    if not dynamic_batch_size:
        compiled_graph_def, dynamic_batch_size = set_dynamic_batch_size(compiled_graph_def)

    # rename NeuronOp's for better visualization
    compiled_graph_def = gdu.prefix_node_names(compiled_graph_def)

    # raise exception if NeuronOp is still uncompiled after fallback pass
    uncompiled_node_names = [node.name for node in gdu.get_neuron_nodes(compiled_graph_def)
                             if not node.attr['executable'].s]
    if uncompiled_node_names:
        raise ValueError('The following subgraphs failed to compile: {}'.format(uncompiled_node_names))

    # execution plan analysis
    compiled_graph_def = gdu.set_execution_plan(compiled_graph_def)

    # return a new graph
    compiled_graph = _graph_def_to_graph(compiled_graph_def)

    # statistics on number of operations
    num_ops_original = len(sess.graph.get_operations())
    num_ops_tfn, num_ops_on_neuron = gdu.compiled_graph_op_counts(compiled_graph_def)
    with utils.logging_show_info():
        logging.info('Number of operations in TensorFlow session: {}'.format(num_ops_original))
        logging.info('Number of operations after tf.neuron optimizations: {}'.format(num_ops_tfn))
        logging.info('Number of operations placed on Neuron runtime: {}'.format(num_ops_on_neuron))
    if ncc.find_neuron_cc() is None:
        logging.warning('***************************************************************')
        logging.warning('')
        logging.warning('  neuron-cc is not found.')
        logging.warning('')
        logging.warning('  To fully optimize TensorFlow model with AWS Neuron, please')
        logging.warning('')
        logging.warning('  install the neuron-cc compiler by "pip install neuron-cc".')
        logging.warning('')
        logging.warning('***************************************************************')
    return compiled_graph