def simple_bind()

in python/mxnet/symbol.py [0:0]


    def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
                    shared_arg_names=None, shared_exec=None, shared_buffer=None, **kwargs):
        """Bind current symbol to get an executor, allocate all the arguments needed.

        This function simplifies the binding procedure. You need to specify only input data shapes.
        Before binding the executor, the function allocates arguments and auxiliary states
        that were not explicitly specified. Allows specifying data types.

        Example
        -------
        >>> x = mx.sym.Variable('x')
        >>> y = mx.sym.FullyConnected(x, num_hidden=4)
        >>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null')
        >>> exe.forward()
        [<NDArray 5x4 @cpu(0)>]
        >>> exe.outputs[0].asnumpy()
        array([[ 0.,  0.,  0.,  0.],
               [ 0.,  0.,  0.,  0.],
               [ 0.,  0.,  0.,  0.],
               [ 0.,  0.,  0.,  0.],
               [ 0.,  0.,  0.,  0.]], dtype=float32)
        >>> exe.arg_arrays
        [<NDArray 5x4 @cpu(0)>, <NDArray 4x4 @cpu(0)>, <NDArray 4 @cpu(0)>]
        >>> exe.grad_arrays
        [<NDArray 5x4 @cpu(0)>, <NDArray 4x4 @cpu(0)>, <NDArray 4 @cpu(0)>]

        Parameters
        ----------
        ctx : Context
            The device context the generated executor to run on.

        grad_req : str, list of str, or dict of str to str, optional
            To specify how we should update the gradient to the `args_grad`.
            Each request is one of {'write', 'add', 'null'}.

            - 'write' means every time gradient is written to specified `args_grad` NDArray.
            - 'add' means every time gradient is added to the specified NDArray.
            - 'null' means no action is taken, the gradient may not be calculated.

        type_dict : Dict of str->numpy.dtype
            Input type dictionary, name->dtype. Entries whose dtype is not present in
            `_DTYPE_NP_TO_MX` are skipped.

        group2ctx : Dict of string to mx.Context
            The dict mapping the `ctx_group` attribute to the context assignment.

        shared_arg_names : List of string
            The argument names whose `NDArray` of shared_exec can be reused for initializing
            the current executor.

        shared_exec : Executor
            The executor whose arg_arrays, grad_arrays, and aux_arrays can be
            reused for initializing the current executor.

        shared_buffer : Dict of string to `NDArray`
            The dict mapping argument names to the `NDArray` that can be reused for initializing
            the current executor. This buffer will be checked for reuse if one argument name
            of the current executor is not found in `shared_arg_names`. The dict is updated
            in place with the entries reported back by the bind call.

        kwargs : Dict of str->shape
            Input shape dictionary, name->shape. Only values that are tuples are forwarded
            as shapes; any other value is ignored.

        Returns
        -------
        executor : mxnet.Executor
            The generated executor

        Raises
        ------
        TypeError
            If `grad_req` is neither None, a string, a list, nor a dict.
        ValueError
            If `shared_arg_names` is not a list or None, or `shared_buffer` is not a
            dict or None.
        RuntimeError
            If `grad_req` is an empty list or dict, or if the underlying bind call
            raises `MXNetError` (the message lists the provided keyword arguments).
        """
        # Provided argument dtypes. Null pointers are passed when no type_dict is given.
        num_provided_arg_types = mx_uint(0)
        provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)()  # provided type argument names
        # Same element type as the populated array built below.
        provided_arg_type_data = ctypes.POINTER(ctypes.c_int)()  # provided types
        if type_dict is not None:
            type_names = []
            type_data = []
            for k, v in type_dict.items():
                v = _numpy.dtype(v).type
                if v in _DTYPE_NP_TO_MX:
                    type_names.append(c_str(k))
                    type_data.append(ctypes.c_int(_DTYPE_NP_TO_MX[v]))
            num_provided_arg_types = mx_uint(len(type_names))
            provided_arg_type_names = c_array(ctypes.c_char_p, type_names)
            provided_arg_type_data = c_array(ctypes.c_int, type_data)

        provided_arg_shape_data = []  # shape data
        # argument shape index in sdata,
        # e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg
        provided_arg_shape_idx = [0]
        provided_arg_shape_names = []  # provided argument names
        # Argument names are not validated here; only tuple values are treated as shapes.
        for k, v in kwargs.items():
            if isinstance(v, tuple):
                provided_arg_shape_names.append(c_str(k))
                provided_arg_shape_data.extend(v)
                provided_arg_shape_idx.append(len(provided_arg_shape_data))

        # Gradient requests. Three encodings are passed down:
        #   str  -> one type, list length 0, names null
        #   list -> one type per entry, names null
        #   dict -> parallel arrays of names and types
        provided_req_type_list_len = 0
        provided_grad_req_types = ctypes.POINTER(ctypes.c_char_p)()
        provided_grad_req_names = ctypes.POINTER(ctypes.c_char_p)()
        if grad_req is not None:
            if isinstance(grad_req, string_types):
                # provided_req_type_list_len stays 0 to indicate this situation
                grad_req_types = [c_str(grad_req)]
            elif isinstance(grad_req, list):
                if not grad_req:
                    raise RuntimeError('grad_req in simple_bind cannot be an empty list')
                grad_req_types = [c_str(item) for item in grad_req]
                provided_req_type_list_len = len(grad_req_types)
            elif isinstance(grad_req, dict):
                if not grad_req:
                    raise RuntimeError('grad_req in simple_bind cannot be an empty dict')
                grad_req_names = []
                grad_req_types = []
                for k, v in grad_req.items():
                    grad_req_names.append(c_str(k))
                    grad_req_types.append(c_str(v))
                provided_grad_req_names = c_array(ctypes.c_char_p, grad_req_names)
                provided_req_type_list_len = len(grad_req_types)
            else:
                # Fail early with a clear message instead of handing a null pointer
                # to c_array below.
                raise TypeError('grad_req in simple_bind must be a string, list, dict or None, '
                                'got %s' % type(grad_req).__name__)
            provided_grad_req_types = c_array(ctypes.c_char_p, grad_req_types)

        # Mapping from ctx_group names to device type/id, as three parallel arrays.
        num_ctx_map_keys = mx_uint(0)
        ctx_map_keys = ctypes.POINTER(ctypes.c_char_p)()
        ctx_map_dev_types = ctypes.POINTER(ctypes.c_int)()
        ctx_map_dev_ids = ctypes.POINTER(ctypes.c_int)()
        if group2ctx is not None:
            group_keys = []
            group_dev_types = []
            group_dev_ids = []
            for key, val in group2ctx.items():
                group_keys.append(c_str(key))
                group_dev_types.append(ctypes.c_int(val.device_typeid))
                group_dev_ids.append(ctypes.c_int(val.device_id))
            num_ctx_map_keys = mx_uint(len(group_keys))
            ctx_map_keys = c_array(ctypes.c_char_p, group_keys)
            ctx_map_dev_types = c_array(ctypes.c_int, group_dev_types)
            ctx_map_dev_ids = c_array(ctypes.c_int, group_dev_ids)

        # prepare param names
        shared_arg_name_list = []
        if shared_arg_names is not None:
            if not isinstance(shared_arg_names, list):
                raise ValueError('shared_arg_names in simple_bind must be a list or None')
            shared_arg_name_list = [c_str(name) for name in shared_arg_names]

        # prepare shared_buffer; a length of -1 is passed when no buffer is provided
        if shared_buffer is None:
            shared_buffer_len = ctypes.c_int(-1)
            shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)()
            shared_buffer_handles = ctypes.POINTER(NDArrayHandle)()
        else:
            if not isinstance(shared_buffer, dict):
                raise ValueError('shared_buffer in simple_bind must be dict or None')
            buffer_names = []
            buffer_handles = []
            for k, v in shared_buffer.items():
                buffer_names.append(c_str(k))
                buffer_handles.append(v.handle)
            shared_buffer_names = c_array(ctypes.c_char_p, buffer_names)
            shared_buffer_len = ctypes.c_int(len(buffer_handles))
            shared_buffer_handles = c_array(NDArrayHandle, buffer_handles)
        # Output slots for the buffer entries reported back by the bind call.
        updated_shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)()
        updated_shared_buffer_handles = ctypes.POINTER(NDArrayHandle)()

        # prepare shared_exec_handle
        shared_exec_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle()

        # prepare current executor handle
        exe_handle = ExecutorHandle()

        # prepare current executor's in_args, arg_grads, and aux_states
        num_in_args = ctypes.c_uint()
        in_arg_handles = ctypes.POINTER(NDArrayHandle)()
        arg_grad_handles = ctypes.POINTER(NDArrayHandle)()
        num_aux_states = ctypes.c_uint()
        aux_state_handles = ctypes.POINTER(NDArrayHandle)()

        try:
            check_call(_LIB.MXExecutorSimpleBind(self.handle,
                                                 ctypes.c_int(ctx.device_typeid),
                                                 ctypes.c_int(ctx.device_id),
                                                 num_ctx_map_keys,
                                                 ctx_map_keys,
                                                 ctx_map_dev_types,
                                                 ctx_map_dev_ids,
                                                 mx_uint(provided_req_type_list_len),
                                                 provided_grad_req_names,
                                                 provided_grad_req_types,
                                                 mx_uint(len(provided_arg_shape_names)),
                                                 c_array(ctypes.c_char_p, provided_arg_shape_names),
                                                 c_array(mx_uint, provided_arg_shape_data),
                                                 c_array(mx_uint, provided_arg_shape_idx),
                                                 num_provided_arg_types,
                                                 provided_arg_type_names,
                                                 provided_arg_type_data,
                                                 mx_uint(len(shared_arg_name_list)),
                                                 c_array(ctypes.c_char_p, shared_arg_name_list),
                                                 ctypes.byref(shared_buffer_len),
                                                 shared_buffer_names,
                                                 shared_buffer_handles,
                                                 ctypes.byref(updated_shared_buffer_names),
                                                 ctypes.byref(updated_shared_buffer_handles),
                                                 ctypes.byref(num_in_args),
                                                 ctypes.byref(in_arg_handles),
                                                 ctypes.byref(arg_grad_handles),
                                                 ctypes.byref(num_aux_states),
                                                 ctypes.byref(aux_state_handles),
                                                 shared_exec_handle,
                                                 ctypes.byref(exe_handle)))
        except MXNetError as e:
            # Re-raise with the provided keyword arguments listed for easier debugging.
            arg_lines = ["%s: %s\n" % (k, v) for k, v in kwargs.items()]
            error_msg = "simple_bind error. Arguments:\n" + "".join(arg_lines) + "%s" % e
            raise RuntimeError(error_msg)

        # update shared_buffer in place; shared_buffer_len was passed by reference
        # and is read back here as the number of returned entries
        if shared_buffer is not None:
            for i in range(shared_buffer_len.value):
                k = py_str(updated_shared_buffer_names[i])
                v = NDArray(NDArrayHandle(updated_shared_buffer_handles[i]))
                shared_buffer[k] = v

        # create in_args, arg_grads, and aux_states for the current executor
        arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)]
        # A null gradient handle is kept as None at the matching position.
        grad_arrays = [NDArray(NDArrayHandle(arg_grad_handles[i]))
                       if arg_grad_handles[i] is not None
                       else None for i in range(num_in_args.value)]
        aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i]))
                      for i in range(num_aux_states.value)]

        executor = Executor(exe_handle, self, ctx, grad_req, group2ctx)
        executor.arg_arrays = arg_arrays
        executor.grad_arrays = grad_arrays
        executor.aux_arrays = aux_arrays
        return executor