def fuse()

in python/fuse.py
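Decorator that traces the wrapped function into a separate graph, compiles it with neuron-cc into a NEFF executable, and replaces the traced subgraph with a single NeuronOp in the calling graph.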


def fuse(func=None, *, compiler_args=None, name=None, asynchronous=True, timeout=None,
         verbose=0, workdir=None, input_shapes=None, output_shapes=None,
         batch_size=None, dynamic_batch_size=False, executable=b'', grad_func=None):
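    # Support both bare @fuse and parameterized @fuse(...) usage: when called
    # without a function, return a partially applied decorator.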
    if func is None:
        return partial(
            fuse, compiler_args=compiler_args, name=name, asynchronous=asynchronous, timeout=timeout,
            verbose=verbose, workdir=workdir, input_shapes=input_shapes, output_shapes=output_shapes,
            batch_size=batch_size, dynamic_batch_size=dynamic_batch_size, executable=executable,
            grad_func=grad_func)

    @wraps(func)
    def wrapper(*args, **kwargs):
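        # In eager mode there is no Session.run to hook, so compile synchronously
        # and temporarily drop back to graph construction.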
        eager = executing_eagerly()
        if eager:
            is_asynchronous = False
            ops.disable_eager_execution()
        else:
            is_asynchronous = asynchronous

        # if this is a fused gradient, temporarily replace the NeuronOp's input
        # list with a mutable copy so the tracer can rewrite it
        is_gradient = False
        if args and _is_neuron_op(args[0]) and func in _neuron_grad_func_set:
            input_list = args[0].inputs
            new_input_list = list(input_list)
            args[0]._inputs_val = new_input_list
            is_gradient = True

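        # Collect all tensors reachable from the call arguments; they will be
        # replaced by placeholders inside the traced subgraph.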
        inputs_mgr = TensorManager(is_gradient)
        inputs_mgr.track((args, kwargs))

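        # Install a NeuronGraphHook for the current default graph (one per graph);
        # it manages compilation jobs and the fuse-graph scope below.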
        default_graph = ops.get_default_graph()
        if default_graph not in _neuron_graph_to_hook:
            _neuron_graph_to_hook[default_graph] = NeuronGraphHook(default_graph)
        graph_hook = _neuron_graph_to_hook[default_graph]

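        # Trace `func` in a separate graph, feeding placeholders in place of the
        # captured input tensors.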
        fuse_graph = ops.Graph()
        with graph_hook.fuse_graph_scope() as latest_fg_var_list:
            with fuse_graph.as_default():
                inputs_mgr.build_placeholder_mapping()
                new_args, new_kwargs = inputs_mgr.build((args, kwargs))
                func_outputs = func(*new_args, **new_kwargs)

        # restore the NeuronOp's original input list
        if is_gradient:
            args[0]._inputs_val = input_list
            inputs_mgr.is_gradient = False

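        # Outer-graph input tensors and the placeholders that stand in for them
        # inside the traced subgraph.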
        input_tensors = inputs_mgr.tensors()
        placeholders = inputs_mgr.mapped_tensors()
        inputs_mgr.mapping = {value: key for key, value in inputs_mgr.mapping.items()}
        inputs_mgr.build((args, kwargs))
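        # Without an explicit name, derive the op name from the most common name
        # scope among the ops created while tracing `func`.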
        if name is None:
            all_op_names = [op.name for op in fuse_graph.get_operations()
                            if op.name not in inputs_mgr.new_op_names]
            op_name = utils.most_popular_namescope(all_op_names)
        else:
            op_name = name
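        # Tensors returned by `func` become the outputs of the NeuronOp.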
        outputs_mgr = TensorManager()
        outputs_mgr.track(func_outputs)
        outputs = outputs_mgr.tensors()
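        # Per-tensor batch axis for dynamic batching; -1 disables dynamic
        # batching for that tensor.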
        if dynamic_batch_size:
            input_batch_axis = _dynamic_batch_size_axis(placeholders)
            output_batch_axis = _dynamic_batch_size_axis(outputs)
            if dynamic_batch_size == 'force':
                output_batch_axis = [0 for _ in outputs]  # todo: infer from graph + placeholders
        else:
            input_batch_axis = [-1 for _ in placeholders]
            output_batch_axis = [-1 for _ in outputs]
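        # Apply user-provided shape and batch-size overrides before generating
        # the compiler IO configuration.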
        if input_shapes is not None:
            for ts, shape in zip(placeholders, input_shapes):
                ts.set_shape(shape)
        if output_shapes is not None:
            for ts, shape in zip(outputs, output_shapes):
                ts.set_shape(shape)
        if batch_size is not None:
            for ts in placeholders:
                if ts.shape.rank:
                    shape = ts.shape.as_list()
                    if shape[0] is None:
                        shape[0] = batch_size
                        ts.set_shape(shape)
            for ts in outputs:
                if ts.shape.rank:
                    shape = ts.shape.as_list()
                    if shape[0] is None:
                        shape[0] = batch_size
                        ts.set_shape(shape)
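        # Prepare the neuron-cc compilation job; it runs immediately in the
        # synchronous path, or is deferred to the first Session.run in the
        # asynchronous path.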
        io_config = _io_config(placeholders, outputs)
        neuron_get_cc_job_func = partial(
            graph_hook.neuron_get_cc_job, fuse_graph, latest_fg_var_list,
            workdir=workdir, io_config=io_config, compiler_args=compiler_args,
            verbose=verbose, timeout=timeout, op_name=op_name)
        executable_content = executable
        if not executable_content and not is_asynchronous:
            neuron_cc_job, neff_path = neuron_get_cc_job_func()
            neuron_cc_job()
            with open(neff_path, 'rb') as f:
                executable_content = f.read()
        model_config = neff_util.get_model_config(executable_content)
        if eager:
            # hack to allow enable_eager_execution; see tensorflow/python/framework/ops.py
            global_default_graph = ops._default_graph_stack._global_default_graph
            ops._default_graph_stack._global_default_graph = None
            ops.enable_eager_execution()
            ops._default_graph_stack._global_default_graph = global_default_graph
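        # Serialize the traced subgraph and drop large constant values from it.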
        fuse_graph_def = fuse_graph.as_graph_def()
        erase_large_constants(fuse_graph_def)
        with ops.name_scope(op_name):
            output_tensors = neuron_op(
                input_tensors=input_tensors, graph_def=fuse_graph_def.SerializeToString(),
                input_names=[ts.name for ts in placeholders],
                input_shapes=[ts.shape for ts in placeholders],
                input_batch_axis=input_batch_axis,
                output_names=[ts.name for ts in outputs],
                output_dtypes=[ts.dtype for ts in outputs],
                output_shapes=[ts.shape for ts in outputs],
                output_batch_axis=output_batch_axis,
                executable=executable_content,
                model_config=model_config,
            )
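        # Asynchronous path: remember the compilation job for this op and patch
        # Session.run (once) so compilation is triggered on first use.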
        if is_asynchronous and not executable_content:
            graph_hook.map_cc_job_func[output_tensors[0].op] = neuron_get_cc_job_func
            global _neuron_sess_run_decorated
            if not _neuron_sess_run_decorated:
                session.Session.run = neuron_decorate_run(session.Session.run)
                _neuron_sess_run_decorated = True
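        # Register the custom gradient function for this NeuronOp's outputs.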
        if callable(grad_func):
            _neuron_grad_dict[output_tensors[0]] = grad_func
            _neuron_grad_func_set.add(getattr(grad_func, '__wrapped__', grad_func))
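        # Substitute the NeuronOp outputs for the traced outputs in the value
        # returned by `func`.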
        outputs_mgr.mapping = {inner: outer for inner, outer in zip(outputs, output_tensors)}
        return outputs_mgr.build(func_outputs)
    return wrapper
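
A minimal usage sketch. This is an assumption-laden illustration, not taken from this file: the model function, shapes, and the tensorflow.neuron import path are hypothetical, it presumes TensorFlow 1.x graph mode, and actually compiling and running it requires the Neuron SDK (neuron-cc and a Neuron device). With the default asynchronous=True, compilation is deferred until the first Session.run via the patched run method:

import numpy as np
import tensorflow as tf
from tensorflow.neuron import fuse  # import path is an assumption

@fuse(dynamic_batch_size=True)
def simple_relu(x):
    # any graph-mode TensorFlow computation; it is traced into a separate
    # graph and replaced by a single NeuronOp
    return tf.nn.relu(x)

x = tf.placeholder(tf.float32, shape=[None, 128])
y = simple_relu(x)  # builds the NeuronOp; neuron-cc has not run yet
with tf.Session() as sess:
    # the patched Session.run triggers compilation on first use
    print(sess.run(y, feed_dict={x: np.zeros([1, 128], np.float32)}))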