include/mxnet/operator.h

/*!
 * Copyright (c) 2015 by Contributors
 * \file operator.h
 * \brief Operator interface of mxnet.
 * \author Naiyan Wang
 */
#ifndef MXNET_OPERATOR_H_
#define MXNET_OPERATOR_H_

#include <dmlc/base.h>
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <nnvm/node.h>
#include <vector>
#include <map>
#include <string>
#include <utility>
#include "./base.h"
#include "./resource.h"

namespace mxnet {
/*! \brief operation request type to Forward and Backward */
enum OpReqType {
  /*! \brief no operation, do not write anything */
  kNullOp,
  /*! \brief write gradient to provided space */
  kWriteTo,
  /*!
   * \brief perform an inplace write,
   * the target shares memory with one of the input arguments.
   * This option only happens when the executor determines that
   * an in-place write is safe.
   */
  kWriteInplace,
  /*! \brief add to the provided space */
  kAddTo
};
/*!
 * \brief All the possible information needed by Operator.Forward and Backward.
 * This is the superset of RunContext.
 * We use this data structure to bookkeep everything needed by Forward and Backward.
 * \sa Resource
 */
struct OpContext {
  /*! \brief whether it is the training phase */
  int is_train;
  /*! \brief RunContext related resources */
  RunContext run_ctx;
  /*! \brief the callback when the operation completes, used by asynchronous ops */
  engine::CallbackOnComplete async_on_complete;
  /*! \brief Resources requested by the operator */
  std::vector<Resource> requested;
  /*!
   * \brief get mshadow stream from Context
   * \return the mshadow stream
   * \tparam xpu the device type of the stream
   */
  template<typename xpu>
  inline mshadow::Stream<xpu>* get_stream() const {
    return run_ctx.get_stream<xpu>();
  }
};
/*!
 * \brief Operator interface.
 * Operator defines the basic operation unit of an optimized computation graph in mxnet.
 * This interface relies on pre-allocated memory in TBlob; the caller needs to set
 * the memory region in TBlob correctly before calling Forward and Backward.
 *
 * An Operator is generated by an OperatorProperty.
 * To add a new operator (aka. a layer of a neural net) to mxnet, a developer needs to
 * create a new OperatorProperty and its corresponding Operator.
 *
 * \sa TBlob, TShape, OperatorProperty
 */
class Operator {
 public:
  /*! \brief the execution type of the operator */
  enum ExecType {
    /*! \brief Forward/Backward are synchronous calls */
    kSync,
    /*!
     * \brief Forward/Backward are asynchronous;
     * OpContext.async_on_complete will be called when the operation finishes.
     */
    kAsync,
    /*!
     * \brief Cross-device copy operation. This is a special operator
     * that indicates a copy across devices; the input and output can sit on different devices.
     * In the current implementation, the copy operator is specially handled by the executor.
     * This flag is used for special-case treatment and future extension of different copy ops.
     */
    kCrossDeviceCopy
  };
  /*! \brief destructor */
  virtual ~Operator() {}
  /*!
   * \brief perform a forward operation of Operator, save the output to TBlob.
   * \param ctx runtime context available to this call
   * \param in_data array of input data, it is const
   * \param req the request types of the saving operation, can only be kWriteTo or kWriteInplace.
   * \param out_data array of output data; the space of each TBlob in out_data
   *   must be pre-allocated with InferShape.
   * \param aux_states auxiliary states of the operator. Most operators do not
   *   need them, but special cases like BatchNorm require them.
   * \sa OpReqType, OpContext
   */
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_states) = 0;
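  /*
   * Example (an illustrative sketch, not part of the original interface):
   * a minimal Forward that copies its single input to its single output.
   * The class name `IdentityOp` is hypothetical.
   *
   * \code
   * class IdentityOp : public Operator {
   *  public:
   *   virtual void Forward(const OpContext &ctx,
   *                        const std::vector<TBlob> &in_data,
   *                        const std::vector<OpReqType> &req,
   *                        const std::vector<TBlob> &out_data,
   *                        const std::vector<TBlob> &aux_states) {
   *     using namespace mshadow;
   *     Stream<cpu> *s = ctx.get_stream<cpu>();
   *     Tensor<cpu, 2> in = in_data[0].FlatTo2D<cpu, real_t>(s);
   *     Tensor<cpu, 2> out = out_data[0].FlatTo2D<cpu, real_t>(s);
   *     if (req[0] == kNullOp) return;  // nothing requested, do not touch output
   *     Copy(out, in, s);               // kWriteTo / kWriteInplace both overwrite
   *   }
   * };
   * \endcode
   */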
  /*!
   * \brief Perform a Backward operation, write the gradients to in_grad.
   *
   * \note
   * Convention:
   *   out_grad.size() == OperatorProperty.NumVisibleOutputs()
   *   out_data.size() == OperatorProperty.NumOutputs()
   * out_data can contain additional invisible returns that remember the
   * state carried over from the Forward pass, for example the mask in Dropout.
   * The gradients are passed from the visible returns in this function.
   *
   * \par
   * Not all the TBlobs in the arguments will be available
   * if you override DeclareBackwardDependency of the corresponding OperatorProperty class.
   * Only the dependencies you declared will be available at the corresponding positions;
   * the rest of the parameters are simply dummies where you will get a nullptr.
   * You will be safe if you use the default DeclareBackwardDependency,
   * but declaring only what you need gives the engine more chances for optimization.
   *
   * \param ctx runtime context available to this call
   * \param out_grad the gradients of the outputs of the Operator.
   * \param in_data the array of input data.
   * \param out_data the array of output data.
   * \param req request types of the saving operation, can be all types.
   * \param in_grad the array of gradients we need to write to.
   * \param aux_states auxiliary states of the operator. Most operators do not need them.
   * \sa OperatorProperty, OpReqType, OpContext
   */
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_states) {
    LOG(FATAL) << "Backward is not implemented";
  }
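  /*
   * Example (an illustrative sketch continuing the hypothetical `IdentityOp`
   * above): a matching Backward. Unlike Forward, req here can be any
   * OpReqType, so kAddTo must be handled by accumulating instead of
   * overwriting.
   *
   * \code
   * virtual void Backward(const OpContext &ctx,
   *                       const std::vector<TBlob> &out_grad,
   *                       const std::vector<TBlob> &in_data,
   *                       const std::vector<TBlob> &out_data,
   *                       const std::vector<OpReqType> &req,
   *                       const std::vector<TBlob> &in_grad,
   *                       const std::vector<TBlob> &aux_states) {
   *   using namespace mshadow;
   *   Stream<cpu> *s = ctx.get_stream<cpu>();
   *   Tensor<cpu, 2> ograd = out_grad[0].FlatTo2D<cpu, real_t>(s);
   *   Tensor<cpu, 2> igrad = in_grad[0].FlatTo2D<cpu, real_t>(s);
   *   switch (req[0]) {
   *     case kNullOp:
   *       break;                  // nothing to write
   *     case kWriteTo:
   *     case kWriteInplace:
   *       Copy(igrad, ograd, s);  // overwrite the gradient buffer
   *       break;
   *     case kAddTo:
   *       igrad += ograd;         // accumulate into an existing gradient
   *       break;
   *   }
   * }
   * \endcode
   */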
  /*! \return the execution type of the operator */
  virtual ExecType exec_type() const {
    return kSync;
  }
};
#if DMLC_USE_CXX11
// OperatorProperty allows C++11, while Operator does not rely on it.
/*!
 * \brief OperatorProperty is an object that stores all information about an Operator.
 * It also contains methods to generate context (device) specific operators.
 *
 * It also contains various functions that can be optionally overridden to
 * provide optimization chances for the computation engine.
 */
class OperatorProperty {
 public:
  /*!
   * \brief virtual destructor
   */
  virtual ~OperatorProperty() {}
  /*!
   * \brief Initialize the Operator by setting the parameters.
   * This function needs to be called before all other functions.
   * \param kwargs the keyword argument parameters
   */
  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
  /*!
   * \brief Get a map representation of the internal parameters.
   * This can be used by Init to recover the state of the OperatorProperty.
   */
  virtual std::map<std::string, std::string> GetParams() const = 0;
  /*!
   * \brief Get the input arguments of the Operator.
   * \return vector of arguments.
   */
  virtual std::vector<std::string> ListArguments() const {
    return {"data"};
  }
  /*!
   * \brief Get the names of the output values of the Operator.
   * \return names of the output values.
   */
  virtual std::vector<std::string> ListOutputs() const {
    return {"output"};
  }
  /*!
   * \brief Get the names of the auxiliary states of the Operator.
   * \return names of the auxiliary states.
   */
  virtual std::vector<std::string> ListAuxiliaryStates() const {
    return {};
  }
  /*! \return number of real return values of the Operator */
  virtual int NumOutputs() const {
    return this->ListOutputs().size();
  }
  /*!
   * \brief Get the number of visible return values during Symbol creation.
   * If NumVisibleOutputs() == k and NumOutputs() == n,
   * the first k returns will be presented in the resulting symbol.
   *
   * The rest of the returns can be used as auxiliary states for Backward.
   * For example, Dropout will return [data, mask], with NumVisibleOutputs() == 1,
   * so when the user calls sym = Dropout(input), only data is presented in sym.
   * But all the returns will be presented in the out_data parameter of Backward if requested.
   *
   * \return number of default return values
   */
  virtual int NumVisibleOutputs() const {
    return NumOutputs();
  }
  /*!
   * \brief Infer the shapes of outputs and unknown input arguments.
   * \param in_shape the shapes of the input arguments of the operator;
   *     this should be of the same length as the vector returned by ListArguments.
   *     in_shape allows unknown elements, which are checked by shape.ndim() == 0.
   *     For unknown shapes, InferShape will try to fill in the correct shape in in_shape;
   *     for known shapes, InferShape will check shape consistency.
   *
   *     Common practice: set the shape of the data input; the weights' shapes
   *     can usually be inferred from it.
   *
   * \param out_shape the shapes of the outputs of the operator.
   *     InferShape will modify the vector to fill in the output TShapes.
   * \param aux_shape the shapes of the auxiliary states of the operator.
   *     InferShape will modify the vector to fill in the auxiliary TShapes.
   * \return true if the shape inference is successful, false if there is not enough information.
   * \throws dmlc::Error if the known arg_shapes are inconsistent.
   */
  virtual bool InferShape(std::vector<TShape> *in_shape,
                          std::vector<TShape> *out_shape,
                          std::vector<TShape> *aux_shape) const = 0;
  /*!
   * \brief Infer the data types of outputs and unknown input arguments.
   * \param in_type the types of the input arguments of the operator;
   *     this should be of the same length as the vector returned by ListArguments.
   *     in_type allows unknown elements, which are checked by type == -1.
   *     For unknown types, InferType will try to fill in the correct type in in_type;
   *     for known types, InferType will check type consistency.
   *
   *     Common practice: set the type of the data input; the weights' types
   *     can usually be inferred from it.
   *
   * \param out_type the types of the outputs of the operator.
   *     InferType will modify the vector to fill in the output types.
   * \param aux_type the types of the auxiliary states of the operator.
   *     InferType will modify the vector to fill in the auxiliary types.
   * \return true if the type inference is successful, false if there is not enough information.
   * \throws dmlc::Error if the known arg_types are inconsistent.
   */
  virtual bool InferType(std::vector<int> *in_type,
                         std::vector<int> *out_type,
                         std::vector<int> *aux_type) const {
    CHECK_LE(in_type->size(), this->ListArguments().size());
    int n_in = this->ListArguments().size();
    for (unsigned i = 0; i < in_type->size(); ++i) {
      CHECK(in_type->at(i) == mshadow::default_type_flag ||
            in_type->at(i) == -1) << "Unsupported data type " << in_type->at(i);
    }
    in_type->clear();
    for (int i = 0; i < n_in; ++i) in_type->push_back(mshadow::default_type_flag);

    int n_out = this->ListOutputs().size();
    out_type->clear();
    for (int i = 0; i < n_out; ++i) out_type->push_back(mshadow::default_type_flag);

    int n_aux = this->ListAuxiliaryStates().size();
    aux_type->clear();
    for (int i = 0; i < n_aux; ++i) aux_type->push_back(mshadow::default_type_flag);
    return true;
  }
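  /*
   * Example (an illustrative sketch): a typical InferShape for a fully
   * connected layer with arguments [data, weight, bias]. The property class,
   * its `num_hidden_` field, and the SHAPE_ASSIGN_CHECK helper (defined in
   * src/operator/operator_common.h, not in this header) are assumptions made
   * for illustration only.
   *
   * \code
   * bool InferShape(std::vector<TShape> *in_shape,
   *                 std::vector<TShape> *out_shape,
   *                 std::vector<TShape> *aux_shape) const override {
   *   CHECK_EQ(in_shape->size(), 3);          // [data, weight, bias]
   *   const TShape &dshape = in_shape->at(0);
   *   if (dshape.ndim() == 0) return false;   // data shape not known yet
   *   index_t num_input = dshape.ProdShape(1, dshape.ndim());
   *   SHAPE_ASSIGN_CHECK(*in_shape, 1, mshadow::Shape2(num_hidden_, num_input));
   *   SHAPE_ASSIGN_CHECK(*in_shape, 2, mshadow::Shape1(num_hidden_));
   *   out_shape->clear();
   *   out_shape->push_back(mshadow::Shape2(dshape[0], num_hidden_));
   *   return true;
   * }
   * \endcode
   */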
  /*!
   * \brief Copy this OperatorProperty.
   * \return a pointer to the copied OperatorProperty
   */
  virtual OperatorProperty* Copy() const = 0;
  /*!
   * \brief Create an Operator on a specific context.
   */
  virtual Operator* CreateOperator(Context ctx) const = 0;
  /*!
   * \brief Create an Operator on a specific context and input shape/type.
   * \param ctx context of this operator
   * \param in_shape shapes of the input ndarrays
   * \param in_type dtypes of the input ndarrays
   * \return the created operator
   */
  virtual Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                                     std::vector<int> *in_type) const {
    std::vector<int> out_type, aux_type;
    std::vector<TShape> out_shape, aux_shape;
    out_type.resize(this->ListOutputs().size());
    out_shape.resize(this->ListOutputs().size());
    aux_type.resize(this->ListAuxiliaryStates().size());
    aux_shape.resize(this->ListAuxiliaryStates().size());
    CHECK(InferType(in_type, &out_type, &aux_type));
    CHECK(InferShape(in_shape, &out_shape, &aux_shape));
    return CreateOperator(ctx);
  }
  /*!
   * \brief Return the type string of the Operator.
   * Subclasses override this function.
   * \return The type string.
   */
  virtual std::string TypeString() const = 0;
  //--------------------------------------------------------
  // All the below functions are optional to override.
  //--------------------------------------------------------
  /*!
   * \brief Declare additional resources required in the forward pass.
   * These additional resources will be presented in OpContext.requested
   * in the same order as the returned Resources.
   * \param in_shape The input shapes to the operator, corresponding to the shapes of in_data.
   * \return Additional resource requests
   */
  virtual std::vector<ResourceRequest> ForwardResource(
      const std::vector<TShape> &in_shape) const {
    return std::vector<ResourceRequest>();
  }
  /*!
   * \brief Declare additional resources required in the backward pass.
   * These additional resources will be presented in OpContext.requested
   * in the same order as the returned Resources.
   * \param in_shape The input shapes to the operator, corresponding to the shapes of in_data.
   * \return Additional resource requests
   */
  virtual std::vector<ResourceRequest> BackwardResource(
      const std::vector<TShape> &in_shape) const {
    return std::vector<ResourceRequest>();
  }
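  /*
   * Example (an illustrative sketch): an operator that needs scratch memory
   * in both passes can request it by overriding these two functions; the
   * engine then exposes the granted resource as ctx.requested[0] inside
   * Forward and Backward.
   *
   * \code
   * std::vector<ResourceRequest> ForwardResource(
   *     const std::vector<TShape> &in_shape) const override {
   *   return {ResourceRequest(ResourceRequest::kTempSpace)};
   * }
   * std::vector<ResourceRequest> BackwardResource(
   *     const std::vector<TShape> &in_shape) const override {
   *   return {ResourceRequest(ResourceRequest::kTempSpace)};
   * }
   * \endcode
   */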
  /*!
   * \brief Declare the input requirements of the Backward pass.
   *
   * Only the returned list of variables will be used in Backward.
   * This function is used for memory optimization.
   * It is advised to override it and return only what is actually needed.
   * If this function is not overridden, all the variables will be valid in Backward.
   *
   * \code
   * // The following code declares that Backward needs out_grad[0], in_data[0], in_data[1]
   * vector<int> DeclareBackwardDependency(const vector<int> &out_grad,
   *                                       const vector<int> &in_data,
   *                                       const vector<int> &out_data) const {
   *   return {out_grad[0], in_data[0], in_data[1]};
   * }
   * \endcode
   * \param out_grad gradients of the outputs in the backward pass.
   * \param in_data the input data in the forward pass.
   * \param out_data the output data in the forward pass.
   * \return an integer vector indicating the input requirements
   * \sa BackwardInputs
   */
  virtual std::vector<int> DeclareBackwardDependency(
      const std::vector<int> &out_grad,
      const std::vector<int> &in_data,
      const std::vector<int> &out_data) const {
    // By default, require to see all the inputs, outputs, and output gradients.
    // Remember to override this function to get better performance.
    std::vector<int> ret = out_grad;
    ret.insert(ret.end(), in_data.begin(), in_data.end());
    ret.insert(ret.end(), out_data.begin(), out_data.end());
    return ret;
  }
  /*!
   * \brief Get possible forward in-place options.
   * This function enables optimization to reuse the memory of inputs in outputs.
   * Only override when necessary; by default in-place is disabled.
   *
   * The reason for the void* type in out_data is to distinguish the order
   * of the mapping between the two: the compiler will report an error when
   * the order of in_data and out_data in the pair gets reversed.
   *
   * \code
   * // The following code says out_data[0] can share data with in_data[0]
   * vector<pair<int, void*> > ForwardInplaceOption(const vector<int> &in_data,
   *                                                const vector<void*> &out_data) const {
   *   return {{in_data[0], out_data[0]}};
   * }
   * \endcode
   * \param in_data The input data in the forward pass.
   * \param out_data The output data in the forward pass.
   * \return a list of pairs that map input->output,
   *   indicating possible in-place operations.
   */
  virtual std::vector<std::pair<int, void*> > ForwardInplaceOption(
      const std::vector<int> &in_data,
      const std::vector<void*> &out_data) const {
    return std::vector<std::pair<int, void*> >();
  }
  /*!
   * \brief Get possible backward in-place options.
   * This function enables optimization to reuse the memory of inputs in outputs.
   * Only override when necessary; by default in-place is disabled.
   *
   * The reason for the void* type in in_grad is to distinguish the order
   * of the mapping between the two: the compiler will report an error when
   * the order of the index and the in_grad entry in the pair gets reversed.
   *
   * \code
   * // The following code says in_grad[0] can share data with in_data[0]
   * vector<pair<int, void*> > BackwardInplaceOption(
   *     const std::vector<int> &out_grad,
   *     const std::vector<int> &in_data,
   *     const std::vector<int> &out_data,
   *     const std::vector<void*> &in_grad) const {
   *   return {{in_data[0], in_grad[0]}};
   * }
   * \endcode
   * \param out_grad Gradients of the outputs in the backward pass.
   * \param in_data The input data in the forward pass.
   * \param out_data The output data in the forward pass.
   * \param in_grad Gradients of the inputs in the backward pass.
   * \return a list of pairs that map input->output,
   *   indicating possible in-place operations.
   */
  virtual std::vector<std::pair<int, void*> > BackwardInplaceOption(
      const std::vector<int> &out_grad,
      const std::vector<int> &in_data,
      const std::vector<int> &out_data,
      const std::vector<void*> &in_grad) const {
    return std::vector<std::pair<int, void*> >();
  }
  /*!
   * \brief Get the Backward input dependency for generic types of data.
   * Normally T can be a pointer to Symbol::DataEntry, or an NDArray.
   * This function selects the result list of T according to DeclareBackwardDependency.
   *
   * \param in_data the input data in the forward pass.
   * \param out_data the output data in the forward pass.
   * \param out_grad gradients of the outputs in the backward pass.
   * \tparam T the generic type parameter.
   * \return vector of inputs the Backward operation depends on.
   * \sa DeclareBackwardDependency
   */
  template<typename T>
  inline std::vector<T> BackwardInputs(const std::vector<T> &out_grad,
                                       const std::vector<T> &in_data,
                                       const std::vector<T> &out_data) const {
    int counter = 0;
    std::vector<int> out_grad_index(out_grad.size());
    std::vector<int> in_data_index(in_data.size());
    std::vector<int> out_data_index(out_data.size());
    for (size_t i = 0; i < out_grad_index.size(); ++i) {
      out_grad_index[i] = counter++;
    }
    for (size_t i = 0; i < in_data_index.size(); ++i) {
      in_data_index[i] = counter++;
    }
    for (size_t i = 0; i < out_data_index.size(); ++i) {
      out_data_index[i] = counter++;
    }
    std::vector<T> all_data;
    all_data.insert(all_data.end(), out_grad.begin(), out_grad.end());
    all_data.insert(all_data.end(), in_data.begin(), in_data.end());
    all_data.insert(all_data.end(), out_data.begin(), out_data.end());

    std::vector<int> ret_index = this->DeclareBackwardDependency(
        out_grad_index, in_data_index, out_data_index);

    std::vector<T> ret(ret_index.size());
    for (size_t i = 0; i < ret_index.size(); ++i) {
      ret[i] = all_data[ret_index[i]];
    }
    return ret;
  }
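  /*
   * Example (an illustrative sketch): given the DeclareBackwardDependency
   * override shown earlier that returns {out_grad[0], in_data[0], in_data[1]},
   * an executor could assemble the actual Backward arguments like this.
   * `prop`, `out_grad_blobs`, `in_data_blobs`, and `out_data_blobs` are
   * hypothetical variables of type OperatorProperty* and std::vector<TBlob>.
   *
   * \code
   * std::vector<TBlob> backward_deps = prop->BackwardInputs(
   *     out_grad_blobs, in_data_blobs, out_data_blobs);
   * // backward_deps == {out_grad_blobs[0], in_data_blobs[0], in_data_blobs[1]}
   * \endcode
   */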
  /*!
   * \brief Create an OperatorProperty.
   * \param type_name the type string of the OperatorProperty
   * \return a newly constructed OperatorProperty
   */
  static OperatorProperty *Create(const char* type_name);
};

/*! \brief typedef of the factory function of operator property */
typedef std::function<OperatorProperty *()> OperatorPropertyFactory;
/*!
 * \brief Registry entry for OperatorProperty factory functions.
 */
struct OperatorPropertyReg
    : public dmlc::FunctionRegEntryBase<OperatorPropertyReg,
                                        OperatorPropertyFactory> {
  /*!
   * \brief Set key_var_num_args.
   * When this is set, the API caller is required to pass in an
   * argument with key=key_var_num_args.c_str() and value=num_args,
   * where num_args is the number of positional arguments when calling the function.
   *
   * This is used to pass in the length of positional arguments
   * for operators that can take a variable number of inputs.
   * Most operators do not need to set this property.
   *
   * \param key the key name to be set
   */
  inline OperatorPropertyReg& set_key_var_num_args(const std::string &key) {  // NOLINT(*)
    this->key_var_num_args = key;
    return *this;
  }
  /*!
   * \brief Check if the TypeString of the type matches the registered name.
   */
  inline OperatorPropertyReg& check_name() {
    OperatorProperty *p = this->body();
    std::string type = p->TypeString();
    delete p;
    CHECK_EQ(this->name, type)
        << "Register Name and TypeString mismatch, name=\"" << this->name << "\","
        << " but TypeString=\"" << type << "\"";
    return *this;
  }
  /*! \brief The key for num_args. */
  std::string key_var_num_args;
};

//---------------------------------------------------------------------------------
// The following part is the API registration of operators.
// See also MXNET_REGISTER_SIMPLE_OP in operator_util.h for registering simple ops.
//---------------------------------------------------------------------------------
/*!
 * \brief Macro to register an OperatorProperty
 *
 * \code
 * // example of registering a fully connected operator
 * MXNET_REGISTER_OP_PROPERTY(FullyConnected, FullyConnectedOpProp)
 * .describe("Fully connected layer");
 * \endcode
 */
#define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType)                    \
  DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \
  .set_body([]() { return new OperatorPropertyType(); })                          \
  .set_return_type("NDArray-or-Symbol")                                           \
  .check_name()

#endif  // DMLC_USE_CXX11
}  // namespace mxnet
#endif  // MXNET_OPERATOR_H_