src/operator/contrib/multi_proposal.cc (44 lines of code) (raw):

/*!
 * Copyright (c) 2017 Microsoft
 * Licensed under The Apache-2.0 License [see LICENSE for details]
 * \file multi_proposal.cc
 * \brief CPU-side definition and registration of the MultiProposal operator
 *        (region proposals via RPN). The CPU operator defined here is a stub:
 *        both Forward and Backward abort with LOG(FATAL).
 * \author Xizhou Zhu
*/
#include "./multi_proposal-inl.h"

namespace mxnet {
namespace op {

/*!
 * \brief Placeholder MultiProposal operator.
 *
 * Only the CPU instantiation is created in this file (see CreateOp<cpu> below).
 * It provides no computation: both passes terminate via LOG(FATAL).
 * NOTE(review): presumably the working implementation is the GPU one declared
 * through multi_proposal-inl.h — confirm against the .cu counterpart.
 */
template<typename xpu>
class MultiProposalOp : public Operator {
 public:
  /*! \param param operator hyper-parameters; copied and stored in param_. */
  explicit MultiProposalOp(MultiProposalParam param) : param_(param) {}

  /*! \brief Not implemented for this device; aborts via LOG(FATAL). */
  void Forward(const OpContext &ctx,
               const std::vector<TBlob> &in_data,
               const std::vector<OpReqType> &req,
               const std::vector<TBlob> &out_data,
               const std::vector<TBlob> &aux_states) override {
    LOG(FATAL) << "not implemented";
  }

  /*! \brief Not implemented for this device; aborts via LOG(FATAL). */
  void Backward(const OpContext &ctx,
                const std::vector<TBlob> &out_grad,
                const std::vector<TBlob> &in_data,
                const std::vector<TBlob> &out_data,
                const std::vector<OpReqType> &req,
                const std::vector<TBlob> &in_grad,
                const std::vector<TBlob> &aux_states) override {
    LOG(FATAL) << "not implemented";
  }

 private:
  /*! \brief Stored hyper-parameters (not read by this stub). */
  MultiProposalParam param_;
};  // class MultiProposalOp

/*!
 * \brief CPU factory for MultiProposalOp.
 * \return a heap-allocated operator; ownership passes to the caller, as the
 *         Operator* factory interface requires a raw pointer here.
 */
template<>
Operator *CreateOp<cpu>(MultiProposalParam param) {
  return new MultiProposalOp<cpu>(param);
}

/*!
 * \brief Create the operator instance for the given context.
 * NOTE(review): DO_BIND_DISPATCH is defined elsewhere in MXNet; it is expected
 * to read `ctx` and forward to CreateOp<cpu> or CreateOp<gpu> — confirm.
 */
Operator* MultiProposalProp::CreateOperator(Context ctx) const {
  DO_BIND_DISPATCH(CreateOp, param_);
}

DMLC_REGISTER_PARAMETER(MultiProposalParam);

MXNET_REGISTER_OP_PROPERTY(_contrib_MultiProposal, MultiProposalProp)
.describe("Generate region proposals via RPN")
.add_argument("cls_score", "NDArray-or-Symbol", "Score of how likely proposal is object.")
.add_argument("bbox_pred", "NDArray-or-Symbol", "BBox Predicted deltas from anchors for proposals")
.add_argument("im_info", "NDArray-or-Symbol", "Image size and scale.")
.add_arguments(MultiProposalParam::__FIELDS__());

}  // namespace op
}  // namespace mxnet