src/c_api/c_api.cc (881 lines of code) (raw):

/*!
 *  Copyright (c) 2015 by Contributors
 * \file c_api.cc
 * \brief C API of mxnet
 */
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <dmlc/recordio.h>
#include <dmlc/omp.h>
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <mxnet/operator.h>
#include <mxnet/io.h>
#include <mxnet/c_api.h>
#include <mxnet/kvstore.h>
#include <mxnet/mxrtc.h>
#include <vector>
#include <sstream>
#include <string>
#include <mutex>
#include <memory>
#include <functional>
#include <unordered_map>
#include <utility>
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../engine/profiler.h"

using namespace mxnet;

// Internal function to get the information
// from function registry
// Used to implement MXSymbolGetAtomicSymbolInfo and MXFuncGetInfo
//
// All returned strings are c_str() pointers into the members of `e`. The
// per-argument names, type infos and descriptions are packed back to back
// into this thread's `ret_vec_charp` buffer (names first, then type infos,
// then descriptions), and the three output arrays point into it, so they are
// presumably only valid until the next call that reuses that buffer.
// `return_type` may be null (the data-iterator caller passes nullptr), in
// which case it is left untouched.
template<typename FunRegType>
inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
                                   const char **name,
                                   const char **description,
                                   mx_uint *num_args,
                                   const char ***arg_names,
                                   const char ***arg_type_infos,
                                   const char ***arg_descriptions,
                                   const char **return_type) {
  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();

  API_BEGIN();
  const size_t nargs = e->arguments.size();
  *name = e->name.c_str();
  *description = e->description.c_str();
  *num_args = static_cast<mx_uint>(nargs);
  if (return_type) *return_type = e->return_type.c_str();

  ret->ret_vec_charp.clear();
  ret->ret_vec_charp.reserve(nargs * 3);
  for (size_t i = 0; i < nargs; ++i) {
    ret->ret_vec_charp.push_back(e->arguments[i].name.c_str());
  }
  for (size_t i = 0; i < nargs; ++i) {
    ret->ret_vec_charp.push_back(e->arguments[i].type_info_str.c_str());
  }
  for (size_t i = 0; i < nargs; ++i) {
    ret->ret_vec_charp.push_back(e->arguments[i].description.c_str());
  }
  *arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
  *arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + nargs;
  *arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (nargs * 2);
  API_END();
}

// NOTE: return value is added in API_END
int MXRandomSeed(int seed) {
  API_BEGIN();
  mxnet::RandomSeed(seed);
  API_END();
}

int MXNotifyShutdown() {
  API_BEGIN();
  Engine::Get()->NotifyShutdown();
  API_END();
}

// Configures the profiler. Fails when built without USE_PROFILER=1.
int MXSetProfilerConfig(int mode, const char* filename) {
  // mode, kOnlySymbolic: 0, kAllOperator: 1
  API_BEGIN();
#if MXNET_USE_PROFILER
  engine::Profiler::Get()->SetConfig(engine::Profiler::ProfilerMode(mode),
                                     std::string(filename));
#else
  LOG(FATAL) << "Need to compile with USE_PROFILER=1 for MXNet Profiler";
#endif
  API_END();
}

// Dumps collected profile data; the profiler must have produced output first.
int MXDumpProfile() {
  API_BEGIN();
#if MXNET_USE_PROFILER
  engine::Profiler *profiler = engine::Profiler::Get();
  CHECK(profiler->IsEnableOutput())
    << "Profiler haven't been run. Config and start profiler first";
  profiler->DumpProfile();
#else
  LOG(FATAL) << "Need to compile with USE_PROFILER=1 for MXNet Profiler";
#endif
  API_END();
}

// Starts/stops the profiler. Fails when built without USE_PROFILER=1.
int MXSetProfilerState(int state) {
  // state, kNotRunning: 0, kRunning: 1
  API_BEGIN();
#if MXNET_USE_PROFILER
  engine::Profiler::Get()->SetState(engine::Profiler::ProfilerState(state));
#else
  LOG(FATAL) << "Need to compile with USE_PROFILER=1 for MXNet Profiler";
#endif
  API_END();
}

int MXSetNumOMPThreads(int thread_num) {
  API_BEGIN();
  omp_set_num_threads(thread_num);
  API_END();
}

// Creates an empty (none) NDArray handle; free it with MXNDArrayFree.
int MXNDArrayCreateNone(NDArrayHandle *out) {
  API_BEGIN();
  *out = new NDArray();
  API_END();
}

// Creates an NDArray of the given shape on device (dev_type, dev_id).
int MXNDArrayCreate(const mx_uint *shape,
                    mx_uint ndim,
                    int dev_type,
                    int dev_id,
                    int delay_alloc,
                    NDArrayHandle *out) {
  API_BEGIN();
  *out = new NDArray(
      TShape(shape, shape + ndim),
      Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
      delay_alloc != 0);
  API_END();
}

// Same as MXNDArrayCreate but with an explicit dtype.
int MXNDArrayCreateEx(const mx_uint *shape,
                      mx_uint ndim,
                      int dev_type,
                      int dev_id,
                      int delay_alloc,
                      int dtype,
                      NDArrayHandle *out) {
  API_BEGIN();
  *out = new NDArray(
      TShape(shape, shape + ndim),
      Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
      delay_alloc != 0,
      dtype);
  API_END();
}

// Deserializes one NDArray from `size` bytes at `buf`.
// On failure the partially built NDArray is deleted by the cleanup
// expression passed to API_END_HANDLE_ERROR.
int MXNDArrayLoadFromRawBytes(const void *buf,
                              size_t size,
                              NDArrayHandle *out) {
  NDArray *ptr = nullptr;
  API_BEGIN();
  // The stream is only read from here; its constructor takes a non-const ptr.
  dmlc::MemoryFixedSizeStream strm(const_cast<void*>(buf), size);
  ptr = new NDArray();
  if (!ptr->Load(&strm)) {
    throw dmlc::Error("Invalid NDArray serialization format");
  }
  *out = ptr;
  API_END_HANDLE_ERROR(delete ptr);
}

// Serializes an NDArray into this thread's `ret_str` buffer and returns a
// pointer into it (presumably valid until the next call reusing the buffer).
int MXNDArraySaveRawBytes(NDArrayHandle handle,
                          size_t *out_size,
                          const char **out_buf) {
  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
  API_BEGIN();
  ret->ret_str.resize(0);
  dmlc::MemoryStringStream strm(&ret->ret_str);
  static_cast<NDArray*>(handle)->Save(&strm);
  *out_size = ret->ret_str.length();
  *out_buf = ret->ret_str.c_str();
  API_END();
}

int MXNDArraySyncCopyFromCPU(NDArrayHandle handle,
                             const void *data,
                             size_t size) {
  API_BEGIN();
  static_cast<NDArray*>(handle)->SyncCopyFromCPU(data, size);
  API_END();
}

int MXNDArraySyncCopyToCPU(NDArrayHandle handle,
                           void *data,
                           size_t size) {
  API_BEGIN();
  static_cast<NDArray*>(handle)->SyncCopyToCPU(data, size);
  API_END();
}

int MXNDArrayWaitToRead(NDArrayHandle handle) {
  API_BEGIN();
  static_cast<NDArray*>(handle)->WaitToRead();
  API_END();
}

int MXNDArrayWaitToWrite(NDArrayHandle handle) {
  API_BEGIN();
  static_cast<NDArray*>(handle)->WaitToWrite();
  API_END();
}

int MXNDArrayWaitAll() {
  API_BEGIN();
  Engine::Get()->WaitForAll();
  API_END();
}

// Saves `num_args` NDArrays to `fname`. `keys` is optional; when it is null
// an empty name list is passed to NDArray::Save.
int MXNDArraySave(const char* fname,
                  mx_uint num_args,
                  NDArrayHandle* args,
                  const char** keys) {
  API_BEGIN();
  std::vector<NDArray> data;
  data.reserve(num_args);
  for (mx_uint i = 0; i < num_args; ++i) {
    data.push_back(*static_cast<NDArray*>(args[i]));
  }
  std::vector<std::string> names;
  if (keys != nullptr) {
    names.assign(keys, keys + num_args);
  }
  {
    std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
    mxnet::NDArray::Save(fo.get(), data, names);
  }
  API_END();
}

// Loads NDArrays (and optional names) from `fname`. Each returned handle is a
// newly allocated NDArray (free with MXNDArrayFree); the handle and name
// arrays live in this thread's return buffers.
int MXNDArrayLoad(const char* fname,
                  mx_uint *out_size,
                  NDArrayHandle** out_arr,
                  mx_uint *out_name_size,
                  const char*** out_names) {
  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
  ret->ret_vec_str.clear();
  API_BEGIN();
  std::vector<NDArray> data;
  std::vector<std::string> &names = ret->ret_vec_str;
  {
    std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
    mxnet::NDArray::Load(fi.get(), &data, &names);
  }
  ret->ret_handles.resize(data.size());
  for (size_t i = 0; i < data.size(); ++i) {
    NDArray *ptr = new NDArray();
    *ptr = data[i];
    ret->ret_handles[i] = ptr;
  }
  ret->ret_vec_charp.resize(names.size());
  for (size_t i = 0; i < names.size(); ++i) {
    ret->ret_vec_charp[i] = names[i].c_str();
  }
  *out_size = static_cast<mx_uint>(data.size());
  *out_arr = dmlc::BeginPtr(ret->ret_handles);
  *out_name_size = static_cast<mx_uint>(names.size());
  *out_names = dmlc::BeginPtr(ret->ret_vec_charp);
  API_END();
}

int MXNDArrayFree(NDArrayHandle handle) {
  API_BEGIN();
  delete static_cast<NDArray*>(handle);
  API_END();
}

// Returns a new handle for handle[slice_begin:slice_end].
// The allocation happens inside API_BEGIN so failures are reported as errors
// instead of escaping the C boundary.
int MXNDArraySlice(NDArrayHandle handle,
                   mx_uint slice_begin,
                   mx_uint slice_end,
                   NDArrayHandle *out) {
  NDArray *ptr = nullptr;
  API_BEGIN();
  ptr = new NDArray(static_cast<NDArray*>(handle)->Slice(slice_begin, slice_end));
  *out = ptr;
  API_END_HANDLE_ERROR(delete ptr);
}

// Returns a new handle for handle[idx] (allocation inside API_BEGIN).
int MXNDArrayAt(NDArrayHandle handle,
                mx_uint idx,
                NDArrayHandle *out) {
  NDArray *ptr = nullptr;
  API_BEGIN();
  ptr = new NDArray(static_cast<NDArray*>(handle)->At(idx));
  *out = ptr;
  API_END_HANDLE_ERROR(delete ptr);
}

// Reshapes `handle` to `dims` and returns a new handle.
// Special values in `dims`:
//   -1 : inferred from the remaining dimensions (at most one allowed)
//    0 : copied from the same axis of the original shape
// Any other negative value is rejected.
MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
                               int ndim,
                               int *dims,
                               NDArrayHandle *out) {
  NDArray *ptr = nullptr;
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  const TShape &old_shape = arr->shape();
  TShape new_shape(dims, dims + ndim);
  size_t size = 1;  // product of every dimension that is not inferred
  int pos = -1;     // index of the dimension given as -1, if any
  for (int i = 0; i < ndim; ++i) {
    int dim = dims[i];
    if (dim == -1) {
      CHECK_EQ(pos, -1) << "Invalid new shape " << new_shape
                        << ": more than one dimensions are -1";
      pos = i;
    } else {
      CHECK_GE(dim, 0) << "Invalid new shape " << new_shape
                       << ": dimension " << i << " is " << dim
                       << " (only -1, 0 or positive values are allowed)";
      if (dim == 0) {
        CHECK_LT(static_cast<size_t>(i), old_shape.ndim())
          << "Invalid new shape " << new_shape
          << ": 0 dimension exceeds original shape " << old_shape;
        dim = old_shape[i];
      }
      // size_t avoids the int overflow the product could hit on large shapes.
      size *= static_cast<size_t>(dim);
      new_shape[i] = dim;
    }
  }
  if (pos >= 0) {
    // A zero-sized known part would otherwise be an integer division by zero.
    CHECK_NE(size, 0U) << "Invalid new shape " << new_shape
                       << ": cannot infer the -1 dimension when the other"
                       << " dimensions multiply to 0";
    new_shape[pos] = old_shape.Size() / size;
  }
  ptr = new NDArray(arr->Reshape(new_shape));
  *out = ptr;
  API_END_HANDLE_ERROR(delete ptr);
}

// Returns the shape through this thread's `arg_shape_buffer`; a none array
// reports ndim 0 and leaves *out_pdata untouched.
int MXNDArrayGetShape(NDArrayHandle handle,
                      mx_uint *out_dim,
                      const mx_uint **out_pdata) {
  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  if (!arr->is_none()) {
    const TShape &s = arr->shape();
    *out_dim = s.ndim();
    std::vector<uint32_t>& buffer = ret->arg_shape_buffer;
    buffer.resize(s.ndim());
    nnvm::ShapeTypeCast(s.begin(), s.end(), buffer.data());
    *out_pdata = buffer.data();
  } else {
    *out_dim = 0;
  }
  API_END();
}

// Returns the raw data pointer of a contiguous CPU NDArray (nullptr if none).
int MXNDArrayGetData(NDArrayHandle handle,
                     void **out_pdata) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  if (!arr->is_none()) {
    CHECK(arr->ctx().dev_mask() == cpu::kDevMask)
      << "MXNDArrayGetData can only be called for NDArray on CPU";
    const TBlob &b = arr->data();
    CHECK(b.CheckContiguous());
    // MSHADOW_TYPE_SWITCH (instead of MSHADOW_REAL_TYPE_SWITCH) so that
    // integer-typed arrays can also expose their data pointer.
    MSHADOW_TYPE_SWITCH(arr->dtype(), DType, {
      *out_pdata = b.FlatTo2D<cpu, DType>().dptr_;
    });
  } else {
    *out_pdata = nullptr;
  }
  API_END();
}

// Returns the dtype flag, or -1 for a none array.
int MXNDArrayGetDType(NDArrayHandle handle,
                      int *out_dtype) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  if (!arr->is_none()) {
    *out_dtype = arr->dtype();
  } else {
    *out_dtype = -1;
  }
  API_END();
}

// Returns (dev_type, dev_id), or (0, 0) for a none array.
int MXNDArrayGetContext(NDArrayHandle handle,
                        int *out_dev_type,
                        int *out_dev_id) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  if (!arr->is_none()) {
    const Context &ctx = arr->ctx();
    *out_dev_type = ctx.dev_type;
    *out_dev_id = ctx.dev_id;
  } else {
    *out_dev_type = 0;
    *out_dev_id = 0;
  }
  API_END();
}

int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  *out = new NDArray(arr->Detach());
  API_END();
}

int MXNDArraySetGradState(NDArrayHandle handle, int state) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  arr->set_fresh_out_grad(static_cast<bool>(state));
  API_END();
}

int MXNDArrayGetGradState(NDArrayHandle handle, int *out) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  *out = arr->fresh_out_grad();
  API_END();
}

// Exposes the NDArrayFunctionReg registry list as an array of handles.
int MXListFunctions(mx_uint *out_size,
                    FunctionHandle **out_array) {
  API_BEGIN();
  auto &vec = dmlc::Registry<NDArrayFunctionReg>::List();
  *out_size = static_cast<mx_uint>(vec.size());
  *out_array = (FunctionHandle*)(dmlc::BeginPtr(vec));  //  NOLINT(*)
  API_END();
}

int MXGetFunction(const char *name,
                  FunctionHandle *out) {
  API_BEGIN();
  *out = dmlc::Registry<NDArrayFunctionReg>::Find(name);
  API_END();
}

int MXFuncGetInfo(FunctionHandle fun,
                  const char **name,
                  const char **description,
                  mx_uint *num_args,
                  const char ***arg_names,
                  const char ***arg_type_infos,
                  const char ***arg_descriptions,
                  const char **return_type) {
  return MXAPIGetFunctionRegInfo(static_cast<const NDArrayFunctionReg *>(fun),
                                 name, description, num_args,
                                 arg_names, arg_type_infos, arg_descriptions,
                                 return_type);
}

int MXFuncDescribe(FunctionHandle fun,
                   mx_uint *num_use_vars,
                   mx_uint *num_scalars,
                   mx_uint *num_mutate_vars,
                   int *type_mask) {
  API_BEGIN();
  auto *f = static_cast<const NDArrayFunctionReg*>(fun);
  *num_use_vars = f->num_use_vars;
  *num_scalars = f->num_scalars;
  *num_mutate_vars = f->num_mutate_vars;
  *type_mask = f->type_mask;
  API_END();
}

// Invokes a registered NDArray function with no keyword parameters.
int MXFuncInvoke(FunctionHandle fun,
                 NDArrayHandle *use_vars,
                 mx_float *scalar_args,
                 NDArrayHandle *mutate_vars) {
  API_BEGIN();
  auto *f = static_cast<const NDArrayFunctionReg*>(fun);
  f->body((NDArray**)(use_vars),  //  NOLINT(*)
          scalar_args,
          (NDArray**)(mutate_vars),  //  NOLINT(*)
          0,
          nullptr,
          nullptr);
  API_END();
}

// Invokes a registered NDArray function with `num_params` keyword parameters.
int MXFuncInvokeEx(FunctionHandle fun,
                   NDArrayHandle *use_vars,
                   mx_float *scalar_args,
                   NDArrayHandle *mutate_vars,
                   int num_params,
                   char **param_keys,
                   char **param_vals) {
  API_BEGIN();
  auto *f = static_cast<const NDArrayFunctionReg*>(fun);
  f->body((NDArray**)(use_vars),  //  NOLINT(*)
          scalar_args,
          (NDArray**)(mutate_vars),  //  NOLINT(*)
          num_params,
          param_keys,
          param_vals);
  API_END();
}

//--------------------------------------------
// Part 5: IO Interface
//--------------------------------------------
int MXListDataIters(mx_uint *out_size,
                    DataIterCreator **out_array) {
  API_BEGIN();
  auto &vec = dmlc::Registry<DataIteratorReg>::List();
  *out_size = static_cast<mx_uint>(vec.size());
  *out_array = (DataIterCreator*)(dmlc::BeginPtr(vec));  //  NOLINT(*)
  API_END();
}

int MXDataIterGetIterInfo(DataIterCreator creator,
                          const char **name,
                          const char **description,
                          mx_uint *num_args,
                          const char ***arg_names,
                          const char ***arg_type_infos,
                          const char ***arg_descriptions) {
  DataIteratorReg *e = static_cast<DataIteratorReg *>(creator);
  return MXAPIGetFunctionRegInfo(e, name, description, num_args,
                                 arg_names, arg_type_infos, arg_descriptions,
                                 nullptr);
}

// Creates and initializes a data iterator from string key/value parameters.
// The iterator is deleted if initialization fails.
int MXDataIterCreateIter(DataIterCreator creator,
                         mx_uint num_param,
                         const char **keys,
                         const char **vals,
                         DataIterHandle *out) {
  IIterator<DataBatch> *iter = nullptr;
  API_BEGIN();
  DataIteratorReg *e = static_cast<DataIteratorReg *>(creator);
  iter = e->body();
  std::vector<std::pair<std::string, std::string> > kwargs;
  kwargs.reserve(num_param);
  for (mx_uint i = 0; i < num_param; ++i) {
    kwargs.emplace_back(std::string(keys[i]), std::string(vals[i]));
  }
  iter->Init(kwargs);
  *out = iter;
  API_END_HANDLE_ERROR(delete iter);
}

int MXDataIterFree(DataIterHandle handle) {
  API_BEGIN();
  delete static_cast<IIterator<DataBatch> *>(handle);
  API_END();
}

int MXDataIterBeforeFirst(DataIterHandle handle) {
  API_BEGIN();
  static_cast<IIterator<DataBatch>* >(handle)->BeforeFirst();
  API_END();
}

int MXDataIterNext(DataIterHandle handle, int *out) {
  API_BEGIN();
  *out = static_cast<IIterator<DataBatch>* >(handle)->Next();
  API_END();
}

// Returns a new handle to the label array (db.data[1]) of the current batch.
// A 2-D label of shape [N, 1] is returned reshaped to [N].
int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) {
  NDArray *pndarray = nullptr;
  API_BEGIN();
  const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
  CHECK_GE(db.data.size(), 2U) << "Data batch does not contain a label array";
  const NDArray &label = db.data[1];
  // temp hack to make label 1D
  // TODO(tianjun) make label 1D when label_width=0
  const TShape &shape = label.shape();
  // Require exactly 2 dims: reading shape[1] on a 1-D label is out of bounds,
  // and flattening [N, 1, ...] to [N] would silently drop elements.
  if (shape.ndim() == 2 && shape[1] == 1) {
    pndarray = new NDArray(label.Reshape(mshadow::Shape1(shape[0])));
  } else {
    pndarray = new NDArray(label);
  }
  *out = pndarray;
  API_END_HANDLE_ERROR(delete pndarray);
}

// Returns a pointer to the current batch's index vector (owned by the batch).
int MXDataIterGetIndex(DataIterHandle handle, uint64_t **out_index, uint64_t *out_size) {
  API_BEGIN();
  const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
  *out_size = db.index.size();
  *out_index = const_cast<uint64_t*>(db.index.data());
  API_END();
}

// Returns a new handle to the data array (db.data[0]) of the current batch.
int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) {
  NDArray *pndarray = nullptr;
  API_BEGIN();
  const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
  CHECK(!db.data.empty()) << "Data batch does not contain any data array";
  pndarray = new NDArray(db.data[0]);
  *out = pndarray;
  API_END_HANDLE_ERROR(delete pndarray);
}

int MXDataIterGetPadNum(DataIterHandle handle, int *pad) {
  API_BEGIN();
  const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
  *pad = db.num_batch_padd;
  API_END();
}

int MXKVStoreCreate(const char *type,
                    KVStoreHandle *out) {
  API_BEGIN();
  *out = KVStore::Create(type);
  API_END();
}

int MXKVStoreFree(KVStoreHandle handle) {
  API_BEGIN();
  delete static_cast<KVStore*>(handle);
  API_END();
}

namespace {
// Copies `num` C-side keys into a vector of KeyOut: int keys are copied
// as-is, `const char*` keys are copied into std::string.
template<typename KeyOut, typename KeyIn>
std::vector<KeyOut> CopyKVKeys(mx_uint num, const KeyIn *keys) {
  std::vector<KeyOut> out;
  out.reserve(num);
  for (mx_uint i = 0; i < num; ++i) {
    out.emplace_back(keys[i]);
  }
  return out;
}

// Copies the NDArray objects referenced by `num` handles.
std::vector<NDArray> CopyKVValues(mx_uint num, NDArrayHandle *vals) {
  std::vector<NDArray> out;
  out.reserve(num);
  for (mx_uint i = 0; i < num; ++i) {
    out.push_back(*static_cast<NDArray*>(vals[i]));
  }
  return out;
}

// Collects `num` handles as NDArray pointers (no copies), used by Pull so
// results are written into the caller's arrays.
std::vector<NDArray*> KVValuePointers(mx_uint num, NDArrayHandle *vals) {
  std::vector<NDArray*> out;
  out.reserve(num);
  for (mx_uint i = 0; i < num; ++i) {
    out.push_back(static_cast<NDArray*>(vals[i]));
  }
  return out;
}
}  // namespace

int MXKVStoreInit(KVStoreHandle handle,
                  mx_uint num,
                  const int* keys,
                  NDArrayHandle* vals) {
  API_BEGIN();
  std::vector<int> v_keys = CopyKVKeys<int>(num, keys);
  std::vector<NDArray> v_vals = CopyKVValues(num, vals);
  static_cast<KVStore*>(handle)->Init(v_keys, v_vals);
  API_END();
}

int MXKVStoreInitEx(KVStoreHandle handle,
                    mx_uint num,
                    const char** keys,
                    NDArrayHandle* vals) {
  API_BEGIN();
  std::vector<std::string> v_keys = CopyKVKeys<std::string>(num, keys);
  std::vector<NDArray> v_vals = CopyKVValues(num, vals);
  static_cast<KVStore*>(handle)->Init(v_keys, v_vals);
  API_END();
}

int MXKVStorePush(KVStoreHandle handle,
                  mx_uint num,
                  const int* keys,
                  NDArrayHandle* vals,
                  int priority) {
  API_BEGIN();
  std::vector<int> v_keys = CopyKVKeys<int>(num, keys);
  std::vector<NDArray> v_vals = CopyKVValues(num, vals);
  static_cast<KVStore*>(handle)->Push(v_keys, v_vals, priority);
  API_END();
}

int MXKVStorePushEx(KVStoreHandle handle,
                    mx_uint num,
                    const char** keys,
                    NDArrayHandle* vals,
                    int priority) {
  API_BEGIN();
  std::vector<std::string> v_keys = CopyKVKeys<std::string>(num, keys);
  std::vector<NDArray> v_vals = CopyKVValues(num, vals);
  static_cast<KVStore*>(handle)->Push(v_keys, v_vals, priority);
  API_END();
}

int MXKVStorePull(KVStoreHandle handle,
                  mx_uint num,
                  const int* keys,
                  NDArrayHandle* vals,
                  int priority) {
  API_BEGIN();
  std::vector<int> v_keys = CopyKVKeys<int>(num, keys);
  std::vector<NDArray*> v_vals = KVValuePointers(num, vals);
  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority);
  API_END();
}

int MXKVStorePullEx(KVStoreHandle handle,
                    mx_uint num,
                    const char** keys,
                    NDArrayHandle* vals,
                    int priority) {
  API_BEGIN();
  std::vector<std::string> v_keys = CopyKVKeys<std::string>(num, keys);
  std::vector<NDArray*> v_vals = KVValuePointers(num, vals);
  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority);
  API_END();
}

// Installs a C updater callback. Each invocation passes freshly allocated
// NDArray copies of `recv` and `local`; presumably the callee takes ownership
// of those handles.
int MXKVStoreSetUpdater(KVStoreHandle handle,
                        MXKVStoreUpdater updater,
                        void* updater_handle) {
  API_BEGIN();
  MXKVStoreUpdater * updater_temp = updater;
  void* updater_handle_temp = updater_handle;
  std::function<void(int, const NDArray&, NDArray*)> updt
  = [updater_temp, updater_handle_temp](int key, const NDArray& recv, NDArray* local) {
    NDArray* recv_copy = new NDArray();
    *recv_copy = recv;
    NDArray* local_copy = new NDArray();
    *local_copy = *local;
    updater_temp(key, recv_copy, local_copy, updater_handle_temp);
  };
  static_cast<KVStore*>(handle)->set_updater(updt);
  API_END();
}

int MXKVStoreGetRank(KVStoreHandle handle, int *rank) {
  API_BEGIN();
  *rank = static_cast<KVStore*>(handle)->get_rank();
  API_END();
}

int MXKVStoreGetGroupSize(KVStoreHandle handle, int *size) {
  API_BEGIN();
  *size = static_cast<KVStore*>(handle)->get_group_size();
  API_END();
}

int MXKVStoreBarrier(KVStoreHandle handle) {
  API_BEGIN();
  static_cast<KVStore*>(handle)->Barrier();
  API_END();
}

int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle,
                                  const int barrier_before_exit) {
  API_BEGIN();
  static_cast<KVStore*>(handle)->set_barrier_before_exit(barrier_before_exit);
  API_END();
}

// Initializes the parameter-server environment from `num_vars` key/value pairs
// (later duplicates of a key overwrite earlier ones).
int MXInitPSEnv(mx_uint num_vars,
                const char **keys,
                const char **vals) {
  API_BEGIN();
  std::unordered_map<std::string, std::string> kwargs;
  for (mx_uint i = 0; i < num_vars; ++i) {
    kwargs[std::string(keys[i])] = std::string(vals[i]);
  }
  KVStore::InitPSEnv(kwargs);
  API_END();
}

int MXKVStoreIsWorkerNode(int *ret) {
  API_BEGIN();
  *ret = KVStore::IsWorkerNode();
  API_END();
}

int MXKVStoreIsServerNode(int *ret) {
  API_BEGIN();
  *ret = KVStore::IsServerNode();
  API_END();
}

int MXKVStoreIsSchedulerNode(int *ret) {
  API_BEGIN();
  *ret = KVStore::IsSchedulerNode();
  API_END();
}

// Runs the server loop, forwarding each (head, body) command to `controller`.
int MXKVStoreRunServer(KVStoreHandle handle,
                       MXKVStoreServerController controller,
                       void *controller_handle) {
  API_BEGIN();
  MXKVStoreServerController *controller_temp = controller;
  void *controller_handle_temp = controller_handle;
  auto ctrl = [controller_temp, controller_handle_temp](int head, const std::string& body) {
    controller_temp(head, body.c_str(), controller_handle_temp);
  };
  static_cast<KVStore*>(handle)->RunServer(ctrl);
  API_END();
}

int MXKVStoreSendCommmandToServers(KVStoreHandle handle,
                                   int cmd_id,
                                   const char* cmd_body) {
  API_BEGIN();
  static_cast<KVStore*>(handle)->SendCommandToServers(
      cmd_id, std::string(cmd_body));
  API_END();
}

int MXKVStoreGetType(KVStoreHandle handle,
                     const char** type) {
  API_BEGIN();
  *CHECK_NOTNULL(type) = static_cast<KVStore*>(handle)->type().c_str();
  API_END();
}

int MXKVStoreGetNumDeadNode(KVStoreHandle handle,
                            const int node_id,
                            int *number,
                            const int timeout_sec) {
  API_BEGIN();
  *number = static_cast<KVStore*>(handle)->get_num_dead_node(node_id, timeout_sec);
  API_END();
}

// State behind a RecordIOHandle. A writer context has reader/read_buff null;
// a reader context has writer null. All non-null members are owned here and
// deleted by the matching *Free function.
struct MXRecordIOContext {
  dmlc::RecordIOWriter *writer;
  dmlc::RecordIOReader *reader;
  dmlc::Stream *stream;
  std::string *read_buff;
};

// Opens `uri` for writing. Intermediate objects are held in unique_ptrs so
// nothing leaks if a later allocation throws.
int MXRecordIOWriterCreate(const char *uri,
                           RecordIOHandle *out) {
  API_BEGIN();
  std::unique_ptr<dmlc::Stream> stream(dmlc::Stream::Create(uri, "w"));
  std::unique_ptr<dmlc::RecordIOWriter> writer(
      new dmlc::RecordIOWriter(stream.get()));
  MXRecordIOContext *context = new MXRecordIOContext;
  context->writer = writer.release();
  context->reader = nullptr;
  context->stream = stream.release();
  context->read_buff = nullptr;
  *out = reinterpret_cast<RecordIOHandle>(context);
  API_END();
}

int MXRecordIOWriterFree(RecordIOHandle handle) {
  API_BEGIN();
  MXRecordIOContext *context =
    reinterpret_cast<MXRecordIOContext*>(handle);
  delete context->writer;
  delete context->stream;
  delete context;
  API_END();
}

int MXRecordIOWriterWriteRecord(RecordIOHandle handle,
                                const char *buf, size_t size) {
  API_BEGIN();
  MXRecordIOContext *context =
    reinterpret_cast<MXRecordIOContext*>(handle);
  context->writer->WriteRecord(reinterpret_cast<const void*>(buf), size);
  API_END();
}

int MXRecordIOWriterTell(RecordIOHandle handle, size_t *pos) {
  API_BEGIN();
  MXRecordIOContext *context =
    reinterpret_cast<MXRecordIOContext*>(handle);
  *pos = context->writer->Tell();
  API_END();
}

// Opens `uri` for reading. Intermediate objects are held in unique_ptrs so
// nothing leaks if a later allocation throws.
int MXRecordIOReaderCreate(const char *uri,
                           RecordIOHandle *out) {
  API_BEGIN();
  std::unique_ptr<dmlc::Stream> stream(dmlc::Stream::Create(uri, "r"));
  std::unique_ptr<dmlc::RecordIOReader> reader(
      new dmlc::RecordIOReader(stream.get()));
  std::unique_ptr<std::string> read_buff(new std::string());
  MXRecordIOContext *context = new MXRecordIOContext;
  context->reader = reader.release();
  context->writer = nullptr;
  context->stream = stream.release();
  context->read_buff = read_buff.release();
  *out = reinterpret_cast<RecordIOHandle>(context);
  API_END();
}

int MXRecordIOReaderFree(RecordIOHandle handle) {
  API_BEGIN();
  MXRecordIOContext *context =
    reinterpret_cast<MXRecordIOContext*>(handle);
  delete context->reader;
  delete context->stream;
  delete context->read_buff;
  delete context;
  API_END();
}

// Reads the next record into the context's buffer. At end of input *buf is
// set to nullptr and *size to 0. The returned pointer refers to the context's
// read buffer, which the next read overwrites.
int MXRecordIOReaderReadRecord(RecordIOHandle handle,
                               char const **buf, size_t *size) {
  API_BEGIN();
  MXRecordIOContext *context =
    reinterpret_cast<MXRecordIOContext*>(handle);
  if (context->reader->NextRecord(context->read_buff)) {
    *buf = context->read_buff->c_str();
    *size = context->read_buff->size();
  } else {
    *buf = nullptr;
    *size = 0;
  }
  API_END();
}

int MXRecordIOReaderSeek(RecordIOHandle handle, size_t pos) {
  API_BEGIN();
  MXRecordIOContext *context =
    reinterpret_cast<MXRecordIOContext*>(handle);
  context->reader->Seek(pos);
  API_END();
}

// Builds a runtime-compiled CUDA kernel; requires USE_CUDA=1 and USE_NVRTC=1.
int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
                char** input_names, char** output_names,
                NDArrayHandle* inputs, NDArrayHandle* outputs,
                char* kernel, RtcHandle *out) {
  API_BEGIN();
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  std::vector<std::pair<std::string, NDArray> > input, output;
  for (mx_uint i = 0; i < num_input; ++i) {
    input.push_back(std::pair<std::string, NDArray>(input_names[i],
                                                    *reinterpret_cast<NDArray*>(inputs[i])));
  }
  for (mx_uint i = 0; i < num_output; ++i) {
    output.push_back(std::pair<std::string, NDArray>(output_names[i],
                                                     *reinterpret_cast<NDArray*>(outputs[i])));
  }
  MXRtc *rtc = new MXRtc(name, input, output, kernel);
  *out = reinterpret_cast<RtcHandle>(rtc);
#else
  LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
#endif  // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  API_END();
}

// Launches a kernel created by MXRtcCreate with the given grid/block dims.
int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
              NDArrayHandle* inputs, NDArrayHandle* outputs,
              mx_uint gridDimX,
              mx_uint gridDimY,
              mx_uint gridDimZ,
              mx_uint blockDimX,
              mx_uint blockDimY,
              mx_uint blockDimZ) {
  API_BEGIN();
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  std::vector<NDArray> input, output;
  for (mx_uint i = 0; i < num_input; ++i) {
    input.push_back(*reinterpret_cast<NDArray*>(inputs[i]));
  }
  for (mx_uint i = 0; i < num_output; ++i) {
    output.push_back(*reinterpret_cast<NDArray*>(outputs[i]));
  }
  reinterpret_cast<MXRtc*>(handle)->push(input, output,
                                         gridDimX,
                                         gridDimY,
                                         gridDimZ,
                                         blockDimX,
                                         blockDimY,
                                         blockDimZ);
#else
  LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
#endif  // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  API_END();
}

int MXRtcFree(RtcHandle handle) {
  API_BEGIN();
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  delete reinterpret_cast<MXRtc*>(handle);
#else
  LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
#endif  // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  API_END();
}

int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator) {
  API_BEGIN();
  mxnet::op::CustomOpProp::Register(op_type, creator);
  API_END();
}