perl-package/AI-MXNetCAPI/mxnet.i (1,573 lines of code) (raw):

%module "AI::MXNetCAPI"
/* Strip the "MX" prefix from every wrapped name
 * (e.g. MXNDArrayFree is exposed to Perl as NDArrayFree). */
%rename("%(strip:[MX])s") "";
%include typemaps.i
%include mxnet_typemaps.i
%inline %{
#include <c_api.h>

// Taken as is from http://cpansearch.perl.org/src/COLEMINOR/Games-EternalLands-Binary-Float16-0.01/Float16.xs
/* This method is faster than the OpenEXR implementation (very often
 * used, eg. in Ogre), with the additional benefit of rounding, inspired
 * by James Tursa's half-precision code.
 *
 * Converts the bit pattern of an IEEE-754 binary32 value (passed as a
 * uint32_t, e.g. through union fbits) into a binary16 bit pattern:
 *   - biased exponent < 103 (|f| < 2^-24): flushed to a signed zero;
 *   - biased exponent > 142 (|f| >= 2^16): signed Inf, or 0x7c01|sign
 *     when the input is a NaN (the NaN payload is not preserved);
 *   - biased exponent 103..112: encoded as a half-precision subnormal;
 *   - otherwise: a re-biased normal number (127 - 15 = 112).
 * Rounding adds the first discarded mantissa bit, i.e. ties round away
 * from zero rather than to even. A carry out of the 10-bit mantissa
 * propagates into the exponent field, which can correctly yield Inf. */
static inline uint16_t _float_to_half(uint32_t x)
{
    uint16_t bits = (x >> 16) & 0x8000;   /* sign */
    uint16_t m = (x >> 12) & 0x07ff;      /* 10 mantissa bits + 1 rounding bit */
    unsigned int e = (x >> 23) & 0xff;    /* biased binary32 exponent */
    if (e < 103)
        return bits;
    if (e > 142)
    {
        bits |= 0x7c00u;
        /* Keep NaN a NaN by setting the lowest mantissa bit. */
        bits |= e == 255 && (x & 0x007fffffu);
        return bits;
    }
    if (e < 113)
    {
        /* Subnormal result: restore the implicit leading one, shift it
         * into place and round on the first bit shifted out. */
        m |= 0x0800u;
        bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1);
        return bits;
    }
    bits |= ((e - 112) << 10) | (m >> 1);
    bits += m & 1;
    return bits;
}

/* Lookup data used by _half_to_float to normalise a subnormal mantissa
 * without branching: once the mantissa's top set bit has been smeared
 * into all lower bits, (v * shiftmagic) >> 27 indexes shifttable, which
 * gives the left shift that moves that top bit to bit 23 (the implicit
 * leading-one position of a binary32 mantissa). */
static int const shifttable[32] =
{
    23, 14, 22, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 20, 0,
    15, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 17, 0, 18, 19, 0,
};

static uint32_t const shiftmagic = 0x07c4acddu;

/* This algorithm is similar to the OpenEXR implementation, except it
 * uses branchless code in the denormal path. This is slower than a
 * table version, but will be more friendly to the cache for occasional
 * uses.
 *
 * Converts a binary16 bit pattern into the equivalent binary32 bit
 * pattern. Zeros, subnormals, normals and infinities convert exactly;
 * every NaN becomes the quiet NaN 0x7fc00000 (sign kept, payload lost). */
static inline uint32_t _half_to_float(uint16_t x)
{
    uint32_t s = (x & 0x8000u) << 16;
    /* +0 / -0: only the sign bit survives. */
    if ((x & 0x7fffu) == 0)
        return (uint32_t)x << 16;
    uint32_t e = x & 0x7c00u;
    uint32_t m = x & 0x03ffu;
    if (e == 0)
    {
        /* Subnormal: smear the top set bit downwards, look up the
         * normalising shift, and let the '+' fold the shifted leading one
         * into the exponent field. */
        uint32_t v = m | (m >> 1);
        v |= v >> 2;
        v |= v >> 4;
        v |= v >> 8;
        e = shifttable[(v * shiftmagic) >> 27];
        return s | (((125 - e) << 23) + (m << e));
    }
    if (e == 0x7c00u)
    {
        /* Inf keeps its sign; any NaN collapses to one quiet NaN. */
        if (m == 0)
            return s | 0x7f800000u;
        return s | 0x7fc00000u;
    }
    /* Normal: re-bias the exponent by 112 and widen the mantissa. */
    return s | (((e >> 10) + 112) << 23) | (m << 13);
}

/* Type-punning helper for viewing a float as its uint32_t bit pattern
 * (and back). It is not referenced in the code above; presumably the
 * typemaps use it together with _float_to_half/_half_to_float -- confirm. */
union fbits {
    float f;
    uint32_t x;
};

/* Perl callback trampolines.
 *
 * Each one receives an SV* holding a Perl code reference as its last
 * (void*) argument, pushes its C arguments onto the Perl stack and calls
 * that code reference, discarding any return value (G_DISCARD).
 *
 * ENTER/SAVETMPS ... FREETMPS/LEAVE is the standard perlcall pattern. It
 * limits the mortal argument SVs (sv_2mortal here; SWIG_NewPointerObj is
 * expected to return a mortal as well) to this single call. That way they
 * are released as soon as the callback returns, instead of piling up on
 * the temporaries stack until some outer FREETMPS runs.
 *
 * NOTE(review): without G_EVAL, a die() inside the Perl callback unwinds
 * through the C/C++ frames of the MXNet caller; consider G_EVAL plus an
 * ERRSV check if that turns out to be a problem. */

/* Presumably installed as the KVStore updater: calls
 * callback->(index, recv, local), where recv/local are NDArray handles
 * wrapped as SWIG pointer objects. */
static void KVStore_callback(int index, NDArrayHandle recv, NDArrayHandle local, void* callback)
{
    dSP;

    ENTER;
    SAVETMPS;

    PUSHMARK(SP);
    XPUSHs(sv_2mortal(newSViv(index)));
    XPUSHs(SWIG_NewPointerObj(SWIG_as_voidptr(recv), SWIGTYPE_p_MXNDArray, 0));
    XPUSHs(SWIG_NewPointerObj(SWIG_as_voidptr(local), SWIGTYPE_p_MXNDArray, 0));
    PUTBACK;

    call_sv((SV*)callback, G_DISCARD);

    FREETMPS;
    LEAVE;
}

/* Presumably installed as the KVStore server controller: calls
 * callback->(head, body), with body copied into a new Perl string. */
static void KVStoreServer_callback(int head, const char *body, void* callback)
{
    dSP;

    ENTER;
    SAVETMPS;

    PUSHMARK(SP);
    XPUSHs(sv_2mortal(newSViv(head)));
    XPUSHs(sv_2mortal(newSVpv(body, 0)));
    PUTBACK;

    call_sv((SV*)callback, G_DISCARD);

    FREETMPS;
    LEAVE;
}

/* Matches the ExecutorMonitorCallback signature: calls
 * callback->(name, handle), with handle wrapped as an NDArray pointer
 * object. */
static void ExecutorMonitor_callback(const char* name, NDArrayHandle handle, void* callback)
{
    dSP;

    ENTER;
    SAVETMPS;

    PUSHMARK(SP);
    XPUSHs(sv_2mortal(newSVpv(name, 0)));
    XPUSHs(SWIG_NewPointerObj(SWIG_as_voidptr(handle), SWIGTYPE_p_MXNDArray, 0));
    PUTBACK;

    call_sv((SV*)callback, G_DISCARD);

    FREETMPS;
    LEAVE;
}
%}
%init %{
    /* These SWIG_TypeClientData() calls might break in the future, but
     * %rename should work on these types before that happens.
     *
     * Each call attaches a Perl-side name (the C handle typedef name, e.g.
     * "NDArrayHandle") to the SWIG type descriptor of the matching opaque
     * handle type. Presumably this is the package that wrapped pointers of
     * that type are blessed into -- confirm against the SWIG Perl runtime. */
    SWIG_TypeClientData(SWIGTYPE_p_MXNDArray, (void *)"NDArrayHandle");
    SWIG_TypeClientData(SWIGTYPE_p_MXFunction, (void *)"FunctionHandle");
    SWIG_TypeClientData(SWIGTYPE_p_MXAtomicSymbolCreator, (void *)"AtomicSymbolCreator");
    SWIG_TypeClientData(SWIGTYPE_p_MXSymbol, (void *)"SymbolHandle");
    SWIG_TypeClientData(SWIGTYPE_p_MXExecutor, (void *)"ExecutorHandle");
    SWIG_TypeClientData(SWIGTYPE_p_MXDataIterCreator, (void *)"DataIterCreator");
    SWIG_TypeClientData(SWIGTYPE_p_MXDataIter, (void *)"DataIterHandle");
    SWIG_TypeClientData(SWIGTYPE_p_MXKVStore, (void *)"KVStoreHandle");
    SWIG_TypeClientData(SWIGTYPE_p_MXRecordIO, (void *)"RecordIOHandle");
    SWIG_TypeClientData(SWIGTYPE_p_MXRtc, (void *)"RtcHandle");
    SWIG_TypeClientData(SWIGTYPE_p_MXCachedOp, (void *)"CachedOpHandle");
%}
/*! \brief manually define unsigned int */
typedef unsigned int mx_uint;
/*! \brief manually define float */
typedef float mx_float;
// In c_api.h all the handles are simply void *, cast internally to the
// specific pointer types. Here each handle is instead declared as a
// pointer to its own opaque type (MXNDArray, MXSymbol, ...), so SWIG
// generates a distinct SWIGTYPE_p_* descriptor per handle kind -- the
// descriptors registered in %init above.
/*! \brief handle to NDArray */
typedef MXNDArray *NDArrayHandle;
/*! \brief handle to a mxnet ndarray function that changes NDArray */
typedef MXFunction *FunctionHandle;
/*! \brief handle to a function that takes param and creates symbol */
typedef MXAtomicSymbolCreator *AtomicSymbolCreator;
/*! \brief handle to a symbol that can be bound as an operator */
typedef MXSymbol *SymbolHandle;
/*! \brief handle to an AtomicSymbol */
typedef MXAtomicSymbol *AtomicSymbolHandle;
/*! \brief handle to an Executor */
typedef MXExecutor *ExecutorHandle;
/*! \brief handle to a data iterator creator */
typedef MXDataIterCreator *DataIterCreator;
/*! \brief handle to a DataIterator */
typedef MXDataIter *DataIterHandle;
/*! \brief handle to KVStore */
typedef MXKVStore *KVStoreHandle;
/*! \brief handle to RecordIO */
typedef MXRecordIO *RecordIOHandle;
/*! \brief handle to MXRtc */
typedef MXRtc *RtcHandle;
/*!
\brief handle to a cached operator */
typedef MXCachedOp *CachedOpHandle;
/*!
 * \brief signature of an executor monitor callback:
 *  (array name, NDArray handle, opaque user payload).
 *  ExecutorMonitor_callback in the %inline section has this shape.
 */
typedef void (*ExecutorMonitorCallback)(const char*, NDArrayHandle, void *);
/*!
 * \brief callback table for a custom operator that, judging by the
 *  parameter types, works on raw float buffers (float**) with shapes
 *  passed as unsigned**.
 *  Each callback also takes a trailing void* payload; presumably each
 *  one receives the p_* field of the same name -- confirm with MXNet docs.
 */
struct NativeOpInfo {
  void (*forward)(int, float**, int*, unsigned**, int*, void*);
  void (*backward)(int, float**, int*, unsigned**, int*, void*);
  void (*infer_shape)(int, int*, unsigned**, void*);
  void (*list_outputs)(char***, void*);
  void (*list_arguments)(char***, void*);
  // all functions also pass a payload void* pointer
  void* p_forward;
  void* p_backward;
  void* p_infer_shape;
  void* p_list_outputs;
  void* p_list_arguments;
};
/*!
 * \brief callback table for a custom operator that works on NDArray
 *  handles (passed as void**).
 *  Callbacks return bool -- presumably true on success; confirm.
 *  As with NativeOpInfo, each callback takes a trailing void* payload,
 *  presumably the p_* field of the same name.
 */
struct NDArrayOpInfo {
  bool (*forward)(int, void**, int*, void*);
  bool (*backward)(int, void**, int*, void*);
  bool (*infer_shape)(int, int*, unsigned**, void*);
  bool (*list_outputs)(char***, void*);
  bool (*list_arguments)(char***, void*);
  bool (*declare_backward_dependency)(const int*, const int*, const int*, int*, int**, void*);
  // all functions also pass a payload void* pointer
  void* p_forward;
  void* p_backward;
  void* p_infer_shape;
  void* p_list_outputs;
  void* p_list_arguments;
  void* p_declare_backward_dependency;
};
/*!
 * \brief return the message of the last error.
 *  All int-returning functions in this file return 0 on success and -1
 *  when an error occurred; MXGetLastError can then be called to retrieve
 *  the error.
 *
 *  This function is thread-safe and can be called from different threads.
 * \return error info
 */
const char *MXGetLastError();
//-------------------------------------
// Part 0: Global State setups
//-------------------------------------
/*!
 * \brief Seed the global random number generators in mxnet.
 * \param seed the random number seed.
 * \return 0 when success, -1 when failure happens.
 */
int MXRandomSeed(int seed);
/*!
 * \brief Notify the engine about a shutdown.
 *  This can help the engine print fewer messages to the display.
 *
 *  Users do not have to call this function.
 * \return 0 when success, -1 when failure happens.
 */
int MXNotifyShutdown();
/*!
 * \brief Set up the profiler configuration.
 * \param mode the working mode of the profiler:
 *  record only symbolic operators when mode == 0,
 *  record all operators when mode == 1
 * \param filename where to save the trace file
 * \return 0 when success, -1 when failure happens.
 */
int MXSetProfilerConfig(int mode, const char* filename);
/*!
 * \brief Set the profiler state.
 * \param state the working state of the profiler:
 *  the profiler is not running when state == 0,
 *  the profiler is running when state == 1
 * \return 0 when success, -1 when failure happens.
 */
int MXSetProfilerState(int state);
/*! \brief Save the profile and stop the profiler */
int MXDumpProfile();
/*! \brief Set the number of OMP threads to use */
int MXSetNumOMPThreads(int thread_num);
//-------------------------------------
// Part 1: NDArray creation and deletion
//-------------------------------------
/*!
 * \brief create an NDArray handle that is not initialized;
 *  it can be passed in as a mutate variable
 *  to hold the result of an NDArray operation
 * \param out the returning handle
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayCreateNone(NDArrayHandle *out);
/*!
 * \brief create an NDArray with the specified shape
 * \param shape pointer to the shape (ndim entries). It is declared here
 *  as `in`, presumably so that an input-array typemap from
 *  mxnet_typemaps.i applies.
 * \param ndim the dimension of the shape
 * \param dev_type device type, specifies the device we want to use
 * \param dev_id the device id of the specific device
 * \param delay_alloc whether to delay allocation until
 *  the ndarray is first mutated
 * \param out the returning handle
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayCreate(const mx_uint *in, mx_uint ndim, int dev_type, int dev_id, int delay_alloc, NDArrayHandle *out);
/*!
* \brief create an NDArray with the specified shape and data type
 * \param shape (declared here as `in`) pointer to the shape
 * \param ndim the dimension of the shape
 * \param dev_type device type, specifies the device we want to use
 * \param dev_id the device id of the specific device
 * \param delay_alloc whether to delay allocation until
 *  the ndarray is first mutated
 * \param dtype data type of the created array
 * \param out the returning handle
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayCreateEx(const mx_uint *in, mx_uint ndim, int dev_type, int dev_id, int delay_alloc, int dtype, NDArrayHandle *out);
/*!
 * \brief create an NDArray handle that is loaded from raw bytes.
 * \param buf (declared here as `in`) the head of the raw bytes
 * \param size size of the raw bytes
 * \param out the returning handle
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayLoadFromRawBytes(const void *in, size_t size, NDArrayHandle *out);
/*!
 * \brief save the NDArray into raw bytes.
 * \param handle the NDArray handle
 * \param out_size size of the raw bytes
 * \param out_buf (declared here as `out_array`) the head of the returned memory bytes.
 * \return 0 when success, -1 when failure happens
 */
int MXNDArraySaveRawBytes(NDArrayHandle handle, size_t *out_size, const char **out_array);
/*!
 * \brief Save a list of ndarrays into the file.
 * \param fname name of the file.
 * \param num_args number of arguments to save.
 * \param args (first `in`) the array of NDArrayHandles to be saved.
 * \param keys (second `in`) the names of the NDArrays, optional, can be NULL
 * \return 0 when success, -1 when failure happens
 */
int MXNDArraySave(const char* fname, mx_uint num_args, NDArrayHandle* in, const char** in);
/*!
 * \brief Load a list of ndarrays from the file.
 * \param fname name of the file.
 * \param out_size (first `out_size`) number of ndarrays loaded.
 * \param out_arr (first `out_array`) head of the returned ndarray handles.
 * \param out_name_size (second `out_size`) size of the output name array.
 * \param out_names (second `out_array`) the names of the returned NDArrays, can be NULL
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayLoad(const char* fname, mx_uint *out_size, NDArrayHandle** out_array, mx_uint *out_size, const char*** out_array);
/*!
 * \brief Perform a synchronous copy from a contiguous CPU memory region.
 *
 *  This function will call WaitToWrite before the copy is performed.
 *  This is useful to copy data from an existing memory region that is
 *  not wrapped by an NDArray (thus its dependency is not tracked).
 *
 * \param handle the NDArray handle
 * \param data (declared here as `in`) the data source to copy from.
 * \param size the memory size we want to copy from.
 *  NOTE(review): MXNet's NDArray::SyncCopyFromCPU is believed to take an
 *  element count here rather than a byte count -- confirm against c_api.h.
 * \return 0 when success, -1 when failure happens
 */
int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, const void *in, size_t size);
/*!
 * \brief Perform a synchronous copy to a contiguous CPU memory region.
 *
 *  This function will call WaitToRead before the copy is performed.
 *  This is useful to copy data into an existing memory region that is
 *  not wrapped by an NDArray (thus its dependency is not tracked).
 *
 * \param handle the NDArray handle
 * \param data the destination buffer to copy into. It is declared here as
 *  `in`, even though it is written to.
 * \param size the memory size we want to copy into (see the unit note on
 *  MXNDArraySyncCopyFromCPU).
 * \return 0 when success, -1 when failure happens
 */
int MXNDArraySyncCopyToCPU(NDArrayHandle handle, void *in, size_t size);
/*!
 * \brief Wait until all the pending writes with respect to the NDArray are finished.
 *  Always call this before reading data out synchronously.
 * \param handle the NDArray handle
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayWaitToRead(NDArrayHandle handle);
/*!
 * \brief Wait until all the pending reads/writes with respect to the NDArray are finished.
 *  Always call this before writing data into the NDArray synchronously.
 * \param handle the NDArray handle
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayWaitToWrite(NDArrayHandle handle);
/*!
 * \brief wait until all delayed operations in
 *  the system are completed
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayWaitAll();
/*!
 * \brief free the ndarray handle
 * \param handle the handle to be freed
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayFree(NDArrayHandle handle);
/*!
 * \brief Slice the NDArray along axis 0.
 * \param handle the handle to the NDArray
 * \param slice_begin The beginning index of the slice
 * \param slice_end The ending index of the slice
 * \param out The NDArrayHandle of the sliced NDArray
 * \return 0 when success, -1 when failure happens
 */
int MXNDArraySlice(NDArrayHandle handle, mx_uint slice_begin, mx_uint slice_end, NDArrayHandle *out);
/*!
 * \brief Index the NDArray along axis 0.
 * \param handle the handle to the NDArray
 * \param idx the index
 * \param out The NDArrayHandle of the output NDArray
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayAt(NDArrayHandle handle, mx_uint idx, NDArrayHandle *out);
/*!
 * \brief Reshape the NDArray.
 * \param handle the handle to the ndarray
 * \param ndim number of dimensions of the new shape
 * \param dims (declared here as `in`) the new shape
 * \param out the NDArrayHandle of the reshaped NDArray
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayReshape(NDArrayHandle handle, int ndim, int *in, NDArrayHandle *out);
/*!
 * \brief get the shape of the array
 * \param handle the handle to the ndarray
 * \param out_dim the output dimension
 * \param out_pdata pointer holder to get the data pointer of the shape
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata);
/*!
 * \brief get the content of the data in the NDArray
 * \param handle the handle to the ndarray
 * \param out_pdata pointer holder to get the pointer of the data
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayGetData(NDArrayHandle handle, void **out_pdata);
/*!
 * \brief get the type of the data in the NDArray
 * \param handle the handle to the ndarray
 * \param out_dtype (declared here as `out`) pointer holder to get the type of the data
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayGetDType(NDArrayHandle handle, int *out);
/*!
 * \brief get the context of the NDArray
 * \param handle the handle to the ndarray
 * \param out_dev_type (first `out`) the output device type
 * \param out_dev_id (second `out`) the output device id
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayGetContext(NDArrayHandle handle, int *out, int *out);
/*!
 * \brief detach an ndarray from the computation graph by clearing entry_
 * \param handle NDArray handle
 * \param out receives the resulting (detached) NDArray handle
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out);
/*!
 * \brief set the flag for gradient array state.
 * \param handle NDArray handle
 * \param state the new state.
 * \return 0 when success, -1 when failure happens
 */
int MXNDArraySetGradState(NDArrayHandle handle, int state);
/*!
 * \brief get the flag for gradient array state.
 * \param handle NDArray handle
 * \param out receives the current state.
 * \return 0 when success, -1 when failure happens
 */
int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
//--------------------------------
// Part 2: functions on NDArray
//--------------------------------
/*!
 * \brief list all the available function handles;
 *  most users can use it to list all the needed functions
 * \param out_size the size of the returned array
 * \param out_array the output function array
 * \return 0 when success, -1 when failure happens
 */
int MXListFunctions(mx_uint *out_size, FunctionHandle **out_array);
/*!
 * \brief get the function handle by name
 * \param name the name of the function
 * \param out the corresponding function handle
 * \return 0 when success, -1 when failure happens
 */
int MXGetFunction(const char *name, FunctionHandle *out);
/*!
 * \brief Get the information of the function handle.
 * \param fun The function handle.
* \param name The returned name of the function.
 * \param description The returned description of the function.
 * \param num_args Number of arguments.
 * \param arg_names Names of the arguments.
 * \param arg_type_infos Type information about the arguments.
 * \param arg_descriptions Description information about the arguments.
 *  (This binding does not expose a return_type output.)
 * \return 0 when success, -1 when failure happens
 */
int MXFuncGetInfo(FunctionHandle fun, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions );
/*!
 * \brief get the argument requirements of the function
 * \param fun input function handle
 * \param num_use_vars (1st `out`) how many NDArrays to be passed in as used_vars
 * \param num_scalars (2nd `out`) how many scalar arguments are needed
 * \param num_mutate_vars (3rd `out`) how many NDArrays to be passed in as mutate_vars
 * \param type_mask (4th `out`) the type mask of this function
 * \return 0 when success, -1 when failure happens
 * \sa MXFuncInvoke
 */
int MXFuncDescribe(FunctionHandle fun, mx_uint *out, mx_uint *out, mx_uint *out, int *out);
/*!
 * \brief invoke a function; the sizes of the passed-in argument arrays
 *  must match the counts reported by MXFuncDescribe
 * \param fun the function
 * \param use_vars (1st `in`) the normal arguments passed to the function
 * \param scalar_args (2nd `in`) the scalar arguments
 * \param mutate_vars (3rd `in`) the mutate arguments
 * \return 0 when success, -1 when failure happens
 * \sa MXFuncDescribe
 */
int MXFuncInvoke(FunctionHandle fun, NDArrayHandle *in, mx_float *in, NDArrayHandle *in);
/*!
 * \brief invoke a function with keyword parameters; the sizes of the
 *  passed-in argument arrays must match the counts reported by MXFuncDescribe
 * \param fun the function
 * \param use_vars (1st `in`) the normal arguments passed to the function
 * \param scalar_args (2nd `in`) the scalar arguments
 * \param mutate_vars (3rd `in`) the mutate arguments
 * \param num_params number of keyword parameters
 * \param param_keys (declared here as `keys`) keys for keyword parameters
 * \param param_vals (declared here as `vals`) values for keyword parameters
 * \return 0 when success, -1 when failure happens
 * \sa MXFuncDescribe
 */
int MXFuncInvokeEx(FunctionHandle fun, NDArrayHandle *in, mx_float *in, NDArrayHandle *in, int num_params, char **keys, char **vals);
/*!
 * \brief invoke an nnvm op as an imperative function
 * \param creator (declared here as `in`) the op
 * \param num_inputs number of input NDArrays
 * \param inputs (declared here as `in`) input NDArrays
 * \param num_outputs (declared here as `out_size`) number of output NDArrays
 * \param outputs (declared here as `out_array`) output NDArrays
 * \param num_params number of keyword parameters
 * \param param_keys (declared here as `keys`) keys for keyword parameters
 * \param param_vals (declared here as `vals`) values for keyword parameters
 * \return 0 when success, -1 when failure happens
 */
int MXImperativeInvoke(AtomicSymbolCreator in, int num_inputs, NDArrayHandle *in, int *out_size, NDArrayHandle **out_array, int num_params, const char **keys, const char **vals);
/*!
 * \brief set whether to record operators for autograd
 * \param is_training 1 when training, 0 when testing
 * \param prev (declared here as `out`) returns the previous status before this set.
 * \return 0 when success, -1 when failure happens
 */
int MXAutogradSetIsTraining(int is_training, int* out);
/*!
 * \brief mark NDArrays as variables to compute gradients for autograd
 * \param num_var number of variable NDArrays
 * \param var_handles (1st `in`) variable NDArrays
 * \param reqs_array (2nd `in`) an mx_uint array, presumably one gradient
 *  request type per variable -- confirm against c_api.h
 * \param grad_handles (3rd `in`) an NDArrayHandle array, presumably the
 *  buffers that receive each variable's gradient -- confirm
 * \return 0 when success, -1 when failure happens
 */
int MXAutogradMarkVariables(mx_uint num_var, NDArrayHandle *in, mx_uint *in, NDArrayHandle *in);
/*!
 * \brief compute the gradient of outputs w.r.t. variables
 * \param num_output number of output NDArrays
 * \param output_handles (declared here as `in`) output NDArrays
 * \return 0 when success, -1 when failure happens
 */
int MXAutogradComputeGradient(mx_uint num_output, NDArrayHandle* in);
/*!
 * \brief compute the gradient of outputs w.r.t. variables
 * \param num_output number of output NDArrays
 * \param output_handles (1st `in`) output NDArrays
 * \param ograd_handles (2nd `in`) head gradients for the NDArrays
 * \param retain_graph whether to keep the graph after backward
 * \return 0 when success, -1 when failure happens
 */
int MXAutogradBackward(mx_uint num_output, NDArrayHandle* in, NDArrayHandle* in, int retain_graph);
/*!
 * \brief create a cached operator
 * \param handle the symbol to build the cached operator from
 * \param out the created cached operator handle
 * \return 0 when success, -1 when failure happens
 */
int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out);
/*!
 * \brief free a cached operator
 * \param handle the cached operator to free
 * \return 0 when success, -1 when failure happens
 */
int MXFreeCachedOp(CachedOpHandle handle);
/*!
 * \brief invoke a cached operator
 * \param handle the cached operator
 * \param num_inputs number of input NDArrays
 * \param in input NDArrays
 * \param out_size number of output NDArrays
 * \param out_array output NDArrays
 * \return 0 when success, -1 when failure happens
 */
int MXInvokeCachedOp(CachedOpHandle handle, int num_inputs, NDArrayHandle *in, int *out_size, NDArrayHandle **out_array);
//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
/*!
 * \brief list all the available operator names, including entries.
 * \param out_size the size of the returned array
 * \param out_array the output operator name array.
 * \return 0 when success, -1 when failure happens
 */
int MXListAllOpNames(mx_uint *out_size, const char ***out_array);
/*!
 * \brief list all the available AtomicSymbolEntry
 * \param out_size the size of the returned array
 * \param out_array the output AtomicSymbolCreator array
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, AtomicSymbolCreator **out_array);
/*!
 * \brief Get the name of an atomic symbol.
 * \param creator (declared here as `in`) the AtomicSymbolCreator.
 * \param name (declared here as `out`) The returned name of the creator.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator in, const char **out);
/*!
* \brief Get the detailed information about an atomic symbol.
 * \param creator (declared here as `in`) the AtomicSymbolCreator.
 * \param name The returned name of the creator.
 * \param description The returned description of the symbol.
 * \param num_args Number of arguments.
 * \param arg_names Names of the arguments.
 * \param arg_type_infos Type information about the arguments.
 * \param arg_descriptions Description information about the arguments.
 * \param key_var_num_args The keyword argument for specifying a variable number of arguments.
 *  When this parameter has non-zero length, the function allows a variable number
 *  of positional arguments, and the caller needs to pass that number in
 *  MXSymbolCreateAtomicSymbol,
 *  with key = key_var_num_args, and value = number of positional arguments.
 *  (This binding does not expose a return_type output.)
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator in, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions, const char **key_var_num_args );
/*!
 * \brief Create an AtomicSymbol.
 * \param creator (declared here as `in`) the AtomicSymbolCreator
 * \param num_param the number of parameters
 * \param keys the keys to the params
 * \param vals the vals of the params
 * \param out pointer to the created symbol handle
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator in, mx_uint num_param, const char **keys, const char **vals, SymbolHandle *out);
/*!
 * \brief Create a Variable Symbol.
 * \param name name of the variable
 * \param out pointer to the created symbol handle
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolCreateVariable(const char *name, SymbolHandle *out);
/*!
 * \brief Create a Symbol by grouping a list of symbols together
 * \param num_symbols number of symbols to be grouped
 * \param symbols (declared here as `in`) array of symbol handles
 * \param out pointer to the created symbol handle
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolCreateGroup(mx_uint num_symbols, SymbolHandle *in, SymbolHandle *out);
/*!
 * \brief Load a symbol from a json file.
 * \param fname the file name.
 * \param out the output symbol.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out);
/*!
 * \brief Load a symbol from a json string.
 * \param json the json string.
 * \param out the output symbol.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out);
/*!
 * \brief Save a symbol into a json file.
 * \param symbol the input symbol.
 * \param fname the file name.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname);
/*!
 * \brief Save a symbol into a json string
 * \param symbol the input symbol.
 * \param out_json (declared here as `out`) output json string.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out);
/*!
 * \brief Free the symbol handle.
 * \param symbol the symbol
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolFree(SymbolHandle symbol);
/*!
 * \brief Copy the symbol to another handle
 * \param symbol the source symbol
 * \param out used to hold the result of the copy
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
/*!
 * \brief Print the content of the symbol, used for debugging.
 * \param symbol the symbol
 * \param out_str (declared here as `out`) pointer to hold the output string of the printing.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolPrint(SymbolHandle symbol, const char **out);
/*!
 * \brief Get the string name of a symbol
 * \param symbol the source symbol
 * \param out (1st `out`) The result name.
 * \param success (2nd `out`) Whether the result is contained in out.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolGetName(SymbolHandle symbol, const char** out, int *out);
/*!
 * \brief Get a string attribute from a symbol
 * \param symbol the source symbol
 * \param key The key of the attribute.
 * \param out (1st `out`) The result attribute, can be NULL if the attribute does not exist.
 * \param success (2nd `out`) Whether the result is contained in out.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolGetAttr(SymbolHandle symbol, const char* key, const char** out, int *out);
/*!
 * \brief Set a string attribute on a symbol.
 *  NOTE: Setting an attribute on a symbol can affect the semantics (mutable/immutable) of the symbolic graph.
 *
 *  Safe recommendation: use an immutable graph
 *  - Only allow setting attributes during creation of a new symbol, as an optional parameter
 *
 *  Mutable graph (be careful about the semantics):
 *  - Allow setting attributes at any point.
 *  - Mutating an attribute of a node shared by two graphs can confuse users.
 *
 * \param symbol the source symbol
 * \param key (1st `in`) The key of the attribute.
 * \param value (2nd `in`) The value to be saved.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolSetAttr(SymbolHandle symbol, const char* in, const char* in);
/*!
 * \brief Get all attributes from a symbol, including all descendants.
 * \param symbol the source symbol
 * \param out_size The number of output attributes
 * \param out (declared here as `out_array2`) 2*out_size strings representing key-value pairs.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolListAttr(SymbolHandle symbol, mx_uint *out_size, const char*** out_array2);
/*!
 * \brief Get all attributes from a symbol, excluding descendants.
 * \param symbol the source symbol
 * \param out_size The number of output attributes
 * \param out (declared here as `out_array2`) 2*out_size strings representing key-value pairs.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolListAttrShallow(SymbolHandle symbol, mx_uint *out_size, const char*** out_array2);
/*!
 * \brief List the arguments of the symbol.
 * \param symbol the symbol
 * \param out_size output size
 * \param out_str_array (declared here as `out_array`) pointer to hold the output string array
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolListArguments(SymbolHandle symbol, mx_uint *out_size, const char ***out_array);
/*!
 * \brief List the outputs of the symbol.
 * \param symbol the symbol
 * \param out_size output size
 * \param out_str_array (declared here as `out_array`) pointer to hold the output string array
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolListOutputs(SymbolHandle symbol, mx_uint *out_size, const char ***out_array);
/*!
 * \brief Get a symbol that contains all the internals.
 * \param symbol The symbol
 * \param out The output symbol whose outputs are all the internals.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolGetInternals(SymbolHandle symbol, SymbolHandle *out);
/*!
 * \brief Get a symbol that contains only the direct children.
 * \param symbol The symbol
 * \param out The output symbol whose outputs are the direct children.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolGetChildren(SymbolHandle symbol, SymbolHandle *out);
/*!
 * \brief Get the index-th output of the symbol.
 * \param symbol The symbol
 * \param index the index of the output.
 * \param out The output symbol whose outputs are the index-th symbol.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolGetOutput(SymbolHandle symbol, mx_uint index, SymbolHandle *out);
/*!
 * \brief List the auxiliary states of the symbol.
 * \param symbol the symbol
 * \param out_size output size
 * \param out_str_array (declared here as `out_array`) pointer to hold the output string array
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolListAuxiliaryStates(SymbolHandle symbol, mx_uint *out_size, const char ***out_array);
/*!
 * \brief Compose the symbol on other symbols.
* * This function will change the sym hanlde. * To achieve function apply behavior, copy the symbol first * before apply. * * \param sym the symbol to apply * \param name the name of symbol * \param num_args number of arguments * \param keys the key of keyword args (optional) * \param args arguments to sym * \return 0 when success, -1 when failure happens */ int MXSymbolCompose(SymbolHandle sym, const char *name, mx_uint num_args, const char** in, SymbolHandle* in); /*! * \brief Get the gradient graph of the symbol * * \param sym the symbol to get gradient * \param num_wrt number of arguments to get gradient * \param wrt the name of the arguments to get gradient * \param out the returned symbol that has gradient * \return 0 when success, -1 when failure happens */ int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** in, SymbolHandle* out); /*! * \brief infer shape of unknown input shapes given the known one. * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data * The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional. * * \param sym symbol handle * \param num_args numbe of input arguments. * \param keys the key of keyword args (optional) * \param arg_ind_ptr the head pointer of the rows in CSR * \param arg_shape_data the content of the CSR * \param in_shape_size sizeof the returning array of in_shapes * \param in_shape_ndim returning array of shape dimensions of eachs input shape. * \param in_shape_data returning array of pointers to head of the input shape. * \param out_shape_size sizeof the returning array of out_shapes * \param out_shape_ndim returning array of shape dimensions of eachs input shape. * \param out_shape_data returning array of pointers to head of the input shape. * \param aux_shape_size sizeof the returning array of aux_shapes * \param aux_shape_ndim returning array of shape dimensions of eachs auxiliary shape. 
* \param aux_shape_data returning array of pointers to head of the auxiliary shape. * \param complete whether infer shape completes or more information is needed. * \return 0 when success, -1 when failure happens */ int MXSymbolInferShape(SymbolHandle sym, mx_uint num_args, const char** in, const mx_uint *in, const mx_uint *in, mx_uint *in_shape_size, const mx_uint **in_shape_ndim, const mx_uint ***in_shape_data, mx_uint *out_shape_size, const mx_uint **out_shape_ndim, const mx_uint ***out_shape_data, mx_uint *aux_shape_size, const mx_uint **aux_shape_ndim, const mx_uint ***aux_shape_data, int *out); /*! * \brief partially infer shape of unknown input shapes given the known one. * * Return partially inferred results if not all shapes could be inferred. * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data * The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional. * * \param sym symbol handle * \param num_args numbe of input arguments. * \param keys the key of keyword args (optional) * \param arg_ind_ptr the head pointer of the rows in CSR * \param arg_shape_data the content of the CSR * \param in_shape_size sizeof the returning array of in_shapes * \param in_shape_ndim returning array of shape dimensions of eachs input shape. * \param in_shape_data returning array of pointers to head of the input shape. * \param out_shape_size sizeof the returning array of out_shapes * \param out_shape_ndim returning array of shape dimensions of eachs input shape. * \param out_shape_data returning array of pointers to head of the input shape. * \param aux_shape_size sizeof the returning array of aux_shapes * \param aux_shape_ndim returning array of shape dimensions of eachs auxiliary shape. * \param aux_shape_data returning array of pointers to head of the auxiliary shape. * \param complete whether infer shape completes or more information is needed. 
* \note In the declaration below, keys, arg_ind_ptr and arg_shape_data are the
 *       three parameters named `in` (in that order) and complete is `out`.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolInferShapePartial(SymbolHandle sym,
                              mx_uint num_args,
                              const char** in,
                              const mx_uint *in,
                              const mx_uint *in,
                              mx_uint *in_shape_size,
                              const mx_uint **in_shape_ndim,
                              const mx_uint ***in_shape_data,
                              mx_uint *out_shape_size,
                              const mx_uint **out_shape_ndim,
                              const mx_uint ***out_shape_data,
                              mx_uint *aux_shape_size,
                              const mx_uint **aux_shape_ndim,
                              const mx_uint ***aux_shape_data,
                              int *out);
/*!
 * \brief infer type of unknown input types given the known one.
 * The call will be treated as a kwargs call if keys != nullptr or num_args==0, otherwise it is positional.
 * NOTE(review): an earlier version of this comment said the types are packed
 * into a CSR matrix represented by arg_ind_ptr and arg_type_data, but this
 * function takes no arg_ind_ptr; arg_type_data presumably holds one type per
 * argument -- confirm against c_api.h.
 *
 * \param sym symbol handle
 * \param num_args number of input arguments.
 * \param keys the key of keyword args (optional); declared below as the first `in`
 * \param arg_type_data the known argument types; declared below as the second `in`
 * \param in_type_size size of the returned array of in_types
 * \param in_type_data returned array of input types (in_type_size entries).
 * \param out_type_size size of the returned array of out_types
 * \param out_type_data returned array of output types (out_type_size entries).
 * \param aux_type_size size of the returned array of aux_types
 * \param aux_type_data returned array of auxiliary types (aux_type_size entries).
 * \param complete whether infer type completes or more information is needed;
 *        declared below as `out`.
 * \return 0 when success, -1 when failure happens
 */
int MXSymbolInferType(SymbolHandle sym,
                      mx_uint num_args,
                      const char** in,
                      const int *in,
                      mx_uint *in_type_size,
                      const int **in_type_data,
                      mx_uint *out_type_size,
                      const int **out_type_data,
                      mx_uint *aux_type_size,
                      const int **aux_type_data,
                      int *out);
//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
/*!
 * \brief Delete the executor
 * \param handle the executor.
 * \return 0 when success, -1 when failure happens
 */
int MXExecutorFree(ExecutorHandle handle);
/*!
* \brief Print the content of execution plan, used for debugging.
 * \param handle the executor.
 * \param out pointer to hold the output string of the printing.
 * \return 0 when success, -1 when failure happens
 */
int MXExecutorPrint(ExecutorHandle handle, const char **out);
/*!
 * \brief Executor forward method
 *
 * \param handle executor handle
 * \param is_train flag indicating whether the forward pass is for training
 *        (nonzero) or for evaluation (zero)
 * \return 0 when success, -1 when failure happens
 */
int MXExecutorForward(ExecutorHandle handle, int is_train);
/*!
 * \brief Executor run backward
 *
 * \param handle executor handle
 * \param len length of the head_grads array
 * \param head_grads NDArray handles for the heads' gradients; declared below as `in`
 *
 * \return 0 when success, -1 when failure happens
 */
int MXExecutorBackward(ExecutorHandle handle,
                       mx_uint len,
                       NDArrayHandle *in);
/*!
 * \brief Get executor's head NDArray
 *
 * \param handle executor handle
 * \param out_size output ndarray vector size
 * \param out_array output ndarray handles
 * \return 0 when success, -1 when failure happens
 */
int MXExecutorOutputs(ExecutorHandle handle,
                      mx_uint *out_size,
                      NDArrayHandle **out_array);
/*!
 * \brief Generate Executor from symbol
 *
 * \param symbol_handle symbol handle
 * \param dev_type device type
 * \param dev_id device id
 * \param len length (presumably of the in_args, arg_grad_store and grad_req_type arrays)
 * \param in_args in args array
 * \param arg_grad_store arg grads handle array
 * \param grad_req_type grad req array
 * \param aux_states_len length of auxiliary states
 * \param aux_states auxiliary states array
 * \param out output executor handle
 * \note in_args, arg_grad_store, grad_req_type and aux_states are the
 *       parameters named `in` below, in that order.
 * \return 0 when success, -1 when failure happens
 */
int MXExecutorBind(SymbolHandle symbol_handle,
                   int dev_type,
                   int dev_id,
                   mx_uint len,
                   NDArrayHandle *in,
                   NDArrayHandle *in,
                   mx_uint *in,
                   mx_uint aux_states_len,
                   NDArrayHandle *in,
                   ExecutorHandle *out);
/*!
 * \brief Generate Executor from symbol.
 * This is an advanced function that allows specifying a group2ctx map.
 * The user can annotate "ctx_group" attribute to name each group.
*
 * \param symbol_handle symbol handle
 * \param dev_type device type of default context
 * \param dev_id device id of default context
 * \param num_map_keys size of group2ctx map
 * \param map_keys keys of group2ctx map
 * \param map_dev_types device type of group2ctx map
 * \param map_dev_ids device id of group2ctx map
 * \param len length (presumably of the in_args, arg_grad_store and grad_req_type arrays)
 * \param in_args in args array
 * \param arg_grad_store arg grads handle array
 * \param grad_req_type grad req array
 * \param aux_states_len length of auxiliary states
 * \param aux_states auxiliary states array
 * \param out output executor handle
 * \note map_keys, map_dev_types, map_dev_ids, in_args, arg_grad_store,
 *       grad_req_type and aux_states are the parameters named `in` below,
 *       in that order.
 * \return 0 when success, -1 when failure happens
 */
int MXExecutorBindX(SymbolHandle symbol_handle,
                    int dev_type,
                    int dev_id,
                    mx_uint num_map_keys,
                    const char** in,
                    const int* in,
                    const int* in,
                    mx_uint len,
                    NDArrayHandle *in,
                    NDArrayHandle *in,
                    mx_uint *in,
                    mx_uint aux_states_len,
                    NDArrayHandle *in,
                    ExecutorHandle *out);
/*!
 * \brief Generate Executor from symbol.
 * This is an advanced function that allows specifying a group2ctx map.
 * The user can annotate "ctx_group" attribute to name each group.
*
 * \param symbol_handle symbol handle
 * \param dev_type device type of default context
 * \param dev_id device id of default context
 * \param num_map_keys size of group2ctx map
 * \param map_keys keys of group2ctx map
 * \param map_dev_types device type of group2ctx map
 * \param map_dev_ids device id of group2ctx map
 * \param len length (presumably of the in_args, arg_grad_store and grad_req_type arrays)
 * \param in_args in args array
 * \param arg_grad_store arg grads handle array
 * \param grad_req_type grad req array
 * \param aux_states_len length of auxiliary states
 * \param aux_states auxiliary states array
 * \param shared_exec input executor handle for memory sharing
 * \param out output executor handle
 * \note map_keys, map_dev_types, map_dev_ids, in_args, arg_grad_store,
 *       grad_req_type and aux_states are the parameters named `in` below,
 *       in that order.
 * \return 0 when success, -1 when failure happens
 */
int MXExecutorBindEX(SymbolHandle symbol_handle,
                     int dev_type,
                     int dev_id,
                     mx_uint num_map_keys,
                     const char** in,
                     const int* in,
                     const int* in,
                     mx_uint len,
                     NDArrayHandle *in,
                     NDArrayHandle *in,
                     mx_uint *in,
                     mx_uint aux_states_len,
                     NDArrayHandle *in,
                     ExecutorHandle shared_exec,
                     ExecutorHandle *out);
/*!
 * \brief Generate Executor from symbol using per-argument grad_req, shape
 * and dtype hints instead of caller-supplied argument arrays.
 *
 * Each `in` parameter is annotated with its C API name in a trailing
 * comment. The groups separated by //---- comments are, in order: the
 * shared-buffer parameters; the returned argument and gradient arrays
 * (num_in_args, in_args, arg_grads); the returned auxiliary states
 * (num_aux_states, aux_states); and the executor to share memory with
 * followed by the output executor handle.
 * NOTE(review): this description is inferred from the parameter names and
 * types; confirm the exact ownership/lifetime contract against c_api.h.
 */
int MXExecutorSimpleBind(SymbolHandle symbol_handle,
                         int dev_type,
                         int dev_id,
                         const mx_uint num_g2c_keys,
                         const char** in, // g2c_keys,
                         const int* in, // g2c_dev_types,
                         const int* in, // g2c_dev_ids,
                         const mx_uint provided_grad_req_list_len,
                         const char** in, // provided_grad_req_names,
                         const char** in, // provided_grad_req_types,
                         const mx_uint num_provided_arg_shapes,
                         const char** in, // provided_arg_shape_names,
                         const mx_uint* in, // provided_arg_shape_data,
                         const mx_uint* in, // provided_arg_shape_idx,
                         const mx_uint num_provided_arg_dtypes,
                         const char** in, // provided_arg_dtype_names,
                         const int* in, // provided_arg_dtypes,
                         const mx_uint num_shared_arg_names,
                         const char** in, // shared_arg_name_list,
                         //------------
                         int* shared_buffer_len,
                         const char** shared_buffer_name_list,
                         NDArrayHandle* shared_buffer_handle_list,
                         const char*** updated_shared_buffer_name_list,
                         NDArrayHandle** updated_shared_buffer_handle_list,
                         //------------------
                         mx_uint* num_in_args,
                         NDArrayHandle** in_args,
                         NDArrayHandle** arg_grads,
                         //-----------------
                         mx_uint* num_aux_states,
                         NDArrayHandle** aux_states,
                         //----------
                         ExecutorHandle shared_exec_handle,
                         ExecutorHandle* out
);
/*!
 * \brief set a callback to notify the completion of operation
 * \param handle the executor
 * \param callback the monitor callback
 * \param callback_handle opaque pointer, presumably passed back to callback
 *        as its last argument -- confirm against c_api.h
 */
int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                 ExecutorMonitorCallback callback,
                                 void* callback_handle);
//--------------------------------------------
// Part 5: IO Interface
//--------------------------------------------
/*!
 * \brief List all the available iterator entries
 * \param out_size the size of returned iterators
 * \param out_array the output iterator entries
 * \return 0 when success, -1 when failure happens
 */
int MXListDataIters(mx_uint *out_size,
                    DataIterCreator **out_array);
/*!
 * \brief Create an iterator from an iterator creator, initialized with
 * num_param key/value parameters
 * \param handle the iterator creator
 * \param num_param number of parameters
 * \param keys parameter keys
 * \param vals parameter values
 * \param out resulting iterator
 * \return 0 when success, -1 when failure happens
 */
int MXDataIterCreateIter(DataIterCreator handle,
                         mx_uint num_param,
                         const char **keys,
                         const char **vals,
                         DataIterHandle *out);
/*!
 * \brief Get the detailed information about data iterator.
 * \param creator the DataIterCreator.
 * \param name The returned name of the creator.
 * \param description The returned description of the data iterator.
 * \param num_args Number of arguments.
 * \param arg_names Name of the arguments.
 * \param arg_type_infos Type information about the arguments.
 * \param arg_descriptions Description information about the arguments.
 * \return 0 when success, -1 when failure happens
 */
int MXDataIterGetIterInfo(DataIterCreator creator,
                          const char **name,
                          const char **description,
                          mx_uint *num_args,
                          const char ***arg_names,
                          const char ***arg_type_infos,
                          const char ***arg_descriptions);
/*!
* \brief Free the handle to the IO module
 * \param handle the handle pointer to the data iterator
 * \return 0 when success, -1 when failure happens
 */
int MXDataIterFree(DataIterHandle handle);
/*!
 * \brief Move iterator to next position
 * \param handle the handle to iterator
 * \param out receives the return value of next
 * \return 0 when success, -1 when failure happens
 */
int MXDataIterNext(DataIterHandle handle, int *out);
/*!
 * \brief Call iterator.Reset
 * \param handle the handle to iterator
 * \return 0 when success, -1 when failure happens
 */
int MXDataIterBeforeFirst(DataIterHandle handle);
/*!
 * \brief Get the handle to the NDArray of underlying data
 * \param handle the handle pointer to the data iterator
 * \param out handle to underlying data NDArray
 * \return 0 when success, -1 when failure happens
 */
int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out);
/*!
 * \brief Get the image index by array.
 * \param handle the handle pointer to the data iterator
 * \param out_index receives the output index array.
 * \param out_size receives the number of entries in the index array.
 * \return 0 when success, -1 when failure happens
 */
int MXDataIterGetIndex(DataIterHandle handle,
                       uint64_t **out_index,
                       uint64_t *out_size);
/*!
 * \brief Get the padding number in current data batch
 * \param handle the handle pointer to the data iterator
 * \param out receives the pad number of the current batch
 * \return 0 when success, -1 when failure happens
 */
int MXDataIterGetPadNum(DataIterHandle handle, int *out);
/*!
 * \brief Get the handle to the NDArray of underlying label
 * \param handle the handle pointer to the data iterator
 * \param out the handle to underlying label NDArray
 * \return 0 when success, -1 when failure happens
 */
int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out);
//--------------------------------------------
// Part 6: basic KVStore interface
//--------------------------------------------
/*!
* \brief Initialize ps-lite environment variables
 * \param num_vars number of variables to initialize
 * \param keys environment keys
 * \param vals environment values
 */
int MXInitPSEnv(mx_uint num_vars,
                const char **keys,
                const char **vals);
/*!
 * \brief Create a kvstore
 * \param type the type of KVStore
 * \param out receives the handle of the created KVStore
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreCreate(const char *type, KVStoreHandle *out);
/*!
 * \brief Delete a KVStore handle.
 * \param handle handle to the kvstore
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreFree(KVStoreHandle handle);
/*!
 * \brief Init a list of (key,value) pairs in kvstore, where each key is a string
 * \param handle handle to the kvstore
 * \param num the number of key-value pairs
 * \param keys the list of keys; declared below as the first `in`
 * \param vals the list of values; declared below as the second `in`
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreInitEx(KVStoreHandle handle,
                    mx_uint num,
                    const char** in,
                    NDArrayHandle* in);
/*!
 * \brief Push a list of (key,value) pairs to kvstore, where each key is a string
 * \param handle handle to the kvstore
 * \param num the number of key-value pairs
 * \param keys the list of keys; declared below as the first `in`
 * \param vals the list of values; declared below as the second `in`
 * \param priority the priority of the action
 * \return 0 when success, -1 when failure happens
 */
int MXKVStorePushEx(KVStoreHandle handle,
                    mx_uint num,
                    const char** in,
                    NDArrayHandle* in,
                    int priority);
/*!
 * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string
 * \param handle handle to the kvstore
 * \param num the number of key-value pairs
 * \param keys the list of keys; declared below as the first `in`
 * \param vals the list of values; declared below as the second `in`
 * \param priority the priority of the action
 * \return 0 when success, -1 when failure happens
 */
int MXKVStorePullEx(KVStoreHandle handle,
                    mx_uint num,
                    const char** in,
                    NDArrayHandle* in,
                    int priority);
/*!
* \brief user-defined updater for the kvstore
 * It's this updater's responsibility to delete \a recv and \a local
 * \param key the key
 * \param recv the pushed value on this key
 * \param local the value stored on local on this key
 * \param handle The additional handle to the updater
 */
typedef void (MXKVStoreUpdater)(int key,
                                NDArrayHandle recv,
                                NDArrayHandle local,
                                void *handle);
/*!
 * \brief register a push updater
 * \param handle handle to the KVStore
 * \param updater updater function
 * \param callback_handle The additional handle used to invoke the updater
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreSetUpdater(KVStoreHandle handle,
                        MXKVStoreUpdater updater,
                        void *callback_handle);
/*!
 * \brief get the type of the kvstore
 * \param handle handle to the KVStore
 * \param out receives the type string
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreGetType(KVStoreHandle handle,
                     const char** out);
//--------------------------------------------
// Part 6: advanced KVStore for multi-machines
//--------------------------------------------
/**
 * \brief return The rank of this node in its group, which is in [0, GroupSize).
 *
 * \param handle handle to the KVStore
 * \param out receives the node rank
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreGetRank(KVStoreHandle handle,
                     int *out);
/**
 * \brief return The number of nodes in this group, which is
 * - number of workers if `IsWorkerNode() == true`,
 * - number of servers if `IsServerNode() == true`,
 * - 1 if `IsSchedulerNode() == true`,
 * \param handle handle to the KVStore
 * \param out receives the group size
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreGetGroupSize(KVStoreHandle handle,
                          int *out);
/**
 * \brief return whether or not this process is a worker node.
 * \param out 1 for yes, 0 for no
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreIsWorkerNode(int *out);
/**
 * \brief return whether or not this process is a server node.
* \param out 1 for yes, 0 for no
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreIsServerNode(int *out);
/**
 * \brief return whether or not this process is a scheduler node.
 * \param out 1 for yes, 0 for no
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreIsSchedulerNode(int *out);
/**
 * \brief global barrier among all worker machines
 *
 * \param handle handle to the KVStore
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreBarrier(KVStoreHandle handle);
/**
 * \brief set whether to do a barrier when the kvstore is finalized
 *
 * \param handle handle to the KVStore
 * \param barrier_before_exit whether to do a barrier when the kvstore is finalized
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle,
                                  const int barrier_before_exit);
/**
 * \brief the prototype of a server controller
 * \param head the head of the command
 * \param body the body of the command
 * \param controller_handle helper handle for implementing controller
 */
typedef void (MXKVStoreServerController)(int head,
                                         const char *body,
                                         void *controller_handle);
/**
 * \brief Run as server (or scheduler)
 *
 * \param handle handle to the KVStore
 * \param controller the user-defined server controller
 * \param callback_handle helper handle for implementing controller
 * \return 0 when success, -1 when failure happens
 */
int MXKVStoreRunServer(KVStoreHandle handle,
                       MXKVStoreServerController controller,
                       void *callback_handle);
/**
 * \brief Send a command to all server nodes
 *
 * \param handle handle to the KVStore
 * \param cmd_id the head of the command
 * \param cmd_body the body of the command
 * \return 0 when success, -1 when failure happens
 * \note Do not "fix" the triple "m" in the function name here: the declared
 *       name presumably has to match the symbol exported by c_api.h.
 */
int MXKVStoreSendCommmandToServers(KVStoreHandle handle,
                                   int cmd_id,
                                   const char* cmd_body);
/**
 * \brief Get the number of ps dead node(s) specified by {node_id}
 *
 * \param handle handle to the KVStore
 * \param node_id Can be a node group or a single node.
*        kScheduler = 1, kServerGroup = 2, kWorkerGroup = 4
 * \param out receives the number of dead nodes
 * \param timeout_sec A node that fails to send a heartbeat within {timeout_sec}
 *        seconds will be presumed 'dead' (defaults to 60)
 */
int MXKVStoreGetNumDeadNode(KVStoreHandle handle,
                            const int node_id,
                            int *out,
                            const int timeout_sec = 60);
/**
 * \brief Create a RecordIO writer object
 * \param uri path to file
 * \param out handle pointer to the created object
 * \return 0 when success, -1 when failure happens
 */
int MXRecordIOWriterCreate(const char *uri, RecordIOHandle *out);
/**
 * \brief Delete a RecordIO writer object
 * \param handle handle to RecordIO object
 * \return 0 when success, -1 when failure happens
 */
int MXRecordIOWriterFree(RecordIOHandle handle);
/**
 * \brief Write a record to a RecordIO object
 * \param handle handle to RecordIO object
 * \param buf buffer to write
 * \param size size of buffer
 * \return 0 when success, -1 when failure happens
 */
int MXRecordIOWriterWriteRecord(RecordIOHandle handle,
                                const char *buf, size_t size);
/**
 * \brief Get the current writer pointer position
 * \param handle handle to RecordIO object
 * \param out receives the current position
 * \return 0 when success, -1 when failure happens
 */
int MXRecordIOWriterTell(RecordIOHandle handle, size_t *out);
/**
 * \brief Create a RecordIO reader object
 * \param uri path to file
 * \param out handle pointer to the created object
 * \return 0 when success, -1 when failure happens
 */
int MXRecordIOReaderCreate(const char *uri, RecordIOHandle *out);
/**
 * \brief Delete a RecordIO reader object
 * \param handle handle to RecordIO object
 * \return 0 when success, -1 when failure happens
 */
int MXRecordIOReaderFree(RecordIOHandle handle);
/**
 * \brief Read a record from a RecordIO object
 * \param handle handle to RecordIO object
 * \param out_array receives a pointer to the record buffer
 * \param out_size receives the size of the record buffer
 * \return 0 when success, -1 when failure happens
 */
int MXRecordIOReaderReadRecord(RecordIOHandle handle,
                               char const **out_array,
                               size_t *out_size);
/**
 * \brief Set the current reader pointer position
 * \param handle handle to RecordIO object
 * \param pos target position
 * \return 0 when success, -1 when failure happens
 */
int MXRecordIOReaderSeek(RecordIOHandle handle, size_t pos);
/**
 * \brief Create a MXRtc object
 * NOTE(review): the four `in` parameters are presumably the input names,
 * output names, input NDArrays and output NDArrays (num_input/num_output
 * entries respectively) -- confirm against c_api.h.
 */
int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
                char** in, char** in,
                NDArrayHandle* in, NDArrayHandle* in,
                char* kernel, RtcHandle *out);
/**
 * \brief Run cuda kernel
 * NOTE(review): the two `in` parameters are presumably the num_input input
 * and num_output output NDArrays; the remaining arguments are the grid and
 * block dimensions of the launch -- confirm against c_api.h.
 */
int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
              NDArrayHandle* in, NDArrayHandle* in,
              mx_uint gridDimX,
              mx_uint gridDimY,
              mx_uint gridDimZ,
              mx_uint blockDimX,
              mx_uint blockDimY,
              mx_uint blockDimZ);
/**
 * \brief Delete a MXRtc object
 */
int MXRtcFree(RtcHandle handle);
/**
 * \brief Register a custom operator type named op_type whose property
 * objects are produced by creator.
 * NOTE(review): summary inferred from the parameter names; confirm the
 * contract against c_api.h.
 */
int MXCustomOpRegister(const char* op_type,
                       CustomOpPropCreator creator);