inline void Init()

in src/operator/cudnn_pooling-inl.h [158:279]


  /*!
   * \brief Lazily create and configure the cuDNN descriptors used by this op.
   *
   * On the first call (guarded by init_cudnn_) this creates pooling_desc_,
   * in_desc_ and out_desc_ and fills them from the shapes of the single input
   * blob, the single output blob and param_. Later calls only re-check the
   * blob counts and return without touching the descriptors.
   *
   * A 2-element kernel selects the 2-D path (4-D NCHW tensor descriptors).
   * Any other kernel rank takes the N-d path, which views both blobs as 5-D
   * tensors with densely packed strides and reads three kernel/pad/stride
   * entries (presumably 3-D pooling on NCDHW data -- confirm with callers).
   *
   * With global_pool set, both paths use the full spatial extent of the input
   * as the window, a stride of 1 and a padding of 0, so the window yields a
   * single output position per spatial dimension.
   *
   * \param s        GPU stream used to view the blobs as tensors.
   * \param in_data  exactly one blob: the pooling input (pool_enum::kData).
   * \param out_data exactly one blob: the pooling output (pool_enum::kOut).
   */
  inline void Init(mshadow::Stream<gpu> *s,
                   const std::vector<TBlob> &in_data,
                   const std::vector<TBlob> &out_data) {
    using namespace mshadow;
    #if CUDNN_MAJOR >= 5
    // cuDNN >= 5 takes an explicit NaN-propagation flag for pooling.
    nan_prop_ = CUDNN_NOT_PROPAGATE_NAN;
    #endif
    CHECK_EQ(in_data.size(), 1U);
    CHECK_EQ(out_data.size(), 1U);
    if (!init_cudnn_) {
      init_cudnn_ = true;
      if (param_.kernel.ndim() == 2) {
        // 2-D pooling on 4-D NCHW tensors.
        Tensor<gpu, 4, DType> data = in_data[pool_enum::kData].get<gpu, 4, DType>(s);
        Tensor<gpu, 4, DType> out = out_data[pool_enum::kOut].get<gpu, 4, DType>(s);
        mshadow::Shape<4> dshape = data.shape_;
        CUDNN_CALL(cudnnCreatePoolingDescriptor(&pooling_desc_));
        CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_));
        CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_));
        CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc_,
                                              CUDNN_TENSOR_NCHW,
                                              dtype_,
                                              data.shape_[0],
                                              data.shape_[1],
                                              data.shape_[2],
                                              data.shape_[3]));
        CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc_,
                                              CUDNN_TENSOR_NCHW,
                                              dtype_,
                                              out.shape_[0],
                                              out.shape_[1],
                                              out.shape_[2],
                                              out.shape_[3]));
        // Under global_pool the window already spans the whole H x W plane,
        // so padding must be 0 (as in the N-d branch below): a nonzero pad
        // with stride 1 would make cuDNN emit 1 + 2*pad positions per spatial
        // dimension instead of a single one.
        #if CUDNN_MAJOR >= 5
        CUDNN_CALL(cudnnSetPooling2dDescriptor(pooling_desc_,
                                               mode_,
                                               nan_prop_,
                                               param_.global_pool ? dshape[2] : param_.kernel[0],
                                               param_.global_pool ? dshape[3] : param_.kernel[1],
                                               param_.global_pool ? 0 : param_.pad[0],
                                               param_.global_pool ? 0 : param_.pad[1],
                                               param_.global_pool ? 1 : param_.stride[0],
                                               param_.global_pool ? 1 : param_.stride[1]));
        #else
        CUDNN_CALL(cudnnSetPooling2dDescriptor(pooling_desc_,
                                               mode_,
                                               param_.global_pool ? dshape[2] : param_.kernel[0],
                                               param_.global_pool ? dshape[3] : param_.kernel[1],
                                               param_.global_pool ? 0 : param_.pad[0],
                                               param_.global_pool ? 0 : param_.pad[1],
                                               param_.global_pool ? 1 : param_.stride[0],
                                               param_.global_pool ? 1 : param_.stride[1]));
        #endif
      } else {
        // N-d pooling: blobs are viewed as 5-D tensors and described with
        // explicit dims/strides via the Nd descriptor APIs.
        Tensor<gpu, 5, DType> data = in_data[pool_enum::kData].get<gpu, 5, DType>(s);
        Tensor<gpu, 5, DType> out = out_data[pool_enum::kOut].get<gpu, 5, DType>(s);
        CUDNN_CALL(cudnnCreatePoolingDescriptor(&pooling_desc_));
        CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_));
        CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_));
        std::vector<int> ishape = {static_cast<int>(data.shape_[0]),
                                   static_cast<int>(data.shape_[1]),
                                   static_cast<int>(data.shape_[2]),
                                   static_cast<int>(data.shape_[3]),
                                   static_cast<int>(data.shape_[4])};

        // Densely packed (row-major) strides: each stride is the product of
        // all later dimensions, and the innermost stride is 1.
        std::vector<int> istride = {static_cast<int>(ishape[1] * ishape[2] * ishape[3] * ishape[4]),
                                    static_cast<int>(ishape[2] * ishape[3] * ishape[4]),
                                    static_cast<int>(ishape[3] * ishape[4]),
                                    static_cast<int>(ishape[4]),
                                    1};

        std::vector<int> oshape = {static_cast<int>(out.shape_[0]),
                                   static_cast<int>(out.shape_[1]),
                                   static_cast<int>(out.shape_[2]),
                                   static_cast<int>(out.shape_[3]),
                                   static_cast<int>(out.shape_[4])};

        std::vector<int> ostride = {static_cast<int>(oshape[1] * oshape[2] * oshape[3] * oshape[4]),
                                    static_cast<int>(oshape[2] * oshape[3] * oshape[4]),
                                    static_cast<int>(oshape[3] * oshape[4]),
                                    static_cast<int>(oshape[4]),
                                    1};

        // Pooling geometry over the three trailing (spatial) dimensions.
        std::vector<int> kernel_vec = {param_.global_pool ? ishape[2] :
                                                            static_cast<int>(param_.kernel[0]),
                                       param_.global_pool ? ishape[3] :
                                                            static_cast<int>(param_.kernel[1]),
                                       param_.global_pool ? ishape[4] :
                                                            static_cast<int>(param_.kernel[2])};

        std::vector<int> pad_vec = {param_.global_pool ? 0 : static_cast<int>(param_.pad[0]),
                                    param_.global_pool ? 0 : static_cast<int>(param_.pad[1]),
                                    param_.global_pool ? 0 : static_cast<int>(param_.pad[2])};

        std::vector<int> stride_vec = {param_.global_pool ? 1 : static_cast<int>(param_.stride[0]),
                                       param_.global_pool ? 1 : static_cast<int>(param_.stride[1]),
                                       param_.global_pool ? 1 : static_cast<int>(param_.stride[2])};

        CUDNN_CALL(cudnnSetTensorNdDescriptor(in_desc_,
                                              dtype_,
                                              static_cast<int>(ishape.size()),
                                              &ishape[0],
                                              &istride[0]));
        CUDNN_CALL(cudnnSetTensorNdDescriptor(out_desc_,
                                              dtype_,
                                              static_cast<int>(oshape.size()),
                                              &oshape[0],
                                              &ostride[0]));
        #if CUDNN_MAJOR >= 5
        CUDNN_CALL(cudnnSetPoolingNdDescriptor(pooling_desc_,
                                               mode_,
                                               nan_prop_,
                                               static_cast<int>(kernel_vec.size()),
                                               &(kernel_vec[0]),
                                               &(pad_vec[0]),
                                               &(stride_vec[0])));
        #else
        LOG(FATAL) << "3D pooling only supports CUDNN v5 and above";
        #endif
      }
    }
  }