virtual void Forward()

in src/operator/mkl/mkl_pooling-inl.h [153:288]


  // Forward pass through the MKL2017 DNN pooling primitive.
  //
  // On the first call this lazily creates BOTH the forward and backward
  // pooling primitives (Backward reuses poolingBwd), choosing the input
  // layout from either the plain user layout or — when MKL_EXPERIMENTAL is
  // on and the input blob already carries a private MKL buffer — the MKL
  // internal layout, so no input conversion is needed.  The first pass also
  // allocates the workspace buffer (max_idx_data) that max pooling's
  // backward pass consumes.
  //
  // \param ctx      operator context providing the compute stream
  // \param in_data  exactly one input blob (pool_enum::kData)
  // \param req      write request types (unused in this path)
  // \param out_data exactly one output blob (pool_enum::kOut)
  // \param aux_args unused
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_data.size(), 1);
    CHECK_EQ(out_data.size(), 1);
    Stream<xpu> *s = ctx.get_stream<xpu>();
    // Only 2-D pooling (4-D NCHW tensors) is handled by this MKL path.
    if (param_.kernel.ndim() >= 3) {
      LOG(FATAL) << "Not implemented";  // fixed typo: was "Not implmented"
    }
    Tensor<xpu, 4, DType> data = mkl_experimental_direct_get<xpu, 4, DType>(
      in_data[pool_enum::kData], s);
    Tensor<xpu, 4, DType> out = mkl_experimental_direct_get<xpu, 4, DType>(
      out_data[pool_enum::kOut], s);
    if (!init_mkldnn_) {
      LayerSetUp(data, out);
      init_mkldnn_ = true;
    }
    // first_pass is set when the primitives are about to be created, so the
    // workspace allocation / layout setup below runs exactly once.
    auto first_pass = false;
    if (poolingFwd == NULL) first_pass = true;

    // Map the MXNet pooling type onto the MKL algorithm enum.
    dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;

    switch (param_.pool_type) {
    case pool_enum::kMaxPooling:
      algorithm = dnnAlgorithmPoolingMax;
      break;
    case pool_enum::kAvgPooling:
      // NOTE(review): mapping kValid -> AvgIncludePadding looks inverted
      // ("valid" conventionally excludes padding) — confirm against the
      // reference CPU pooling implementation before changing.
      algorithm = (param_.pooling_convention == pool_enum::kValid) ?
          dnnAlgorithmPoolingAvgIncludePadding : dnnAlgorithmPoolingAvg;

      break;
    default:
      LOG(FATAL) << "Unknown pooling method.";
    }

    dnnError_t status;
    void* pooling_res[dnnResourceNumber];

    // With the experimental MKL memory path, try to grab the input's private
    // (internal-layout) buffer; bottom_data stays NULL otherwise.
    void* bottom_data = NULL;
#if MKL_EXPERIMENTAL == 1
    bottom_data =
          reinterpret_cast<void *>(mkl_prv_data<DType>(in_data[pool_enum::kData]));
#endif
    // kFull uses symmetric zero borders; kValid (and anything else) uses
    // asymmetric zero borders.
    dnnBorder_t border_type = dnnBorderZerosAsymm;
    switch (param_.pooling_convention) {
    case pool_enum::kFull:
      border_type = dnnBorderZeros;
      break;
    case pool_enum::kValid:
      border_type = dnnBorderZerosAsymm;
      break;
    default:
      border_type = dnnBorderZerosAsymm;
      break;
    }
    if (NULL == bottom_data) {
      // No private buffer: feed the plain user-layout tensor, and create the
      // primitives against the user layout if they don't exist yet.
      bottom_data = data.dptr_;
      if (NULL == poolingFwd) {
        status = dnnPoolingCreateForward<DType>(&poolingFwd, NULL,
                                                algorithm, fwd_bottom_data->layout_usr,
                                                kernel_size, kernel_stride,
                                                src_offset, border_type);
        CHECK_EQ(status, E_SUCCESS);
        // Now create poolingBwd
        status = dnnPoolingCreateBackward<DType>(&poolingBwd, NULL,
                                                 algorithm, fwd_bottom_data->layout_usr,
                                                 kernel_size, kernel_stride,
                                                 src_offset, border_type);
        CHECK_EQ(status, E_SUCCESS);
      }
    }
#if MKL_EXPERIMENTAL == 1
    if (NULL != bottom_data) {
       if (NULL == poolingFwd) {
          // The input carries a private MKL descriptor: adopt its internal
          // layout for the primitives so the input needs no conversion.
          std::shared_ptr<MKLMemHolder> bottom_data_mem = in_data[pool_enum::kData].Mkl_mem_;
          std::shared_ptr<PrvMemDescr> bottom_prv_descriptor =
            bottom_data_mem->get_prv_descriptor();
          CHECK_EQ(bottom_prv_descriptor->get_descr_type(),
                   PrvMemDescr::PRV_DESCR_MKL2017);
          std::shared_ptr<MKLData<DType> > mem_descr
            = std::static_pointer_cast<MKLData<DType>>(bottom_prv_descriptor);
          CHECK(mem_descr != nullptr);
          fwd_bottom_data = mem_descr;

          status = dnnPoolingCreateForward<DType>(&poolingFwd, NULL,
                                                  algorithm, fwd_bottom_data->layout_int,
                                                  kernel_size, kernel_stride,
                                                  src_offset, border_type);
          CHECK_EQ(status, E_SUCCESS);
          fwd_top_data->create_internal_layout(poolingFwd, dnnResourceDst);

          // Now create poolingBwd
          status = dnnPoolingCreateBackward<DType>(&poolingBwd, NULL,
                                                   algorithm, fwd_bottom_data->layout_int,
                                                   kernel_size, kernel_stride,
                                                   src_offset, border_type);
          CHECK_EQ(status, E_SUCCESS);
          bwd_top_diff->create_internal_layout(poolingFwd, dnnResourceDst);
          bwd_bottom_diff->create_internal_layout(poolingFwd, dnnResourceSrc);
        }
    }
#endif

    if (first_pass) {
      // Allocate the workspace (max-index) buffer sized by the primitive's
      // own workspace layout; the layout handle itself is freed right after.
      dnnLayout_t max_idx_datal = NULL;
      status = dnnLayoutCreateFromPrimitive<DType>(
          &max_idx_datal, poolingFwd, dnnResourceWorkspace);
      CHECK_EQ(status, E_SUCCESS);
      status = dnnAllocateBuffer<DType>(reinterpret_cast<void**>(&max_idx_data), max_idx_datal);
      CHECK_EQ(status, E_SUCCESS);
#if MKL_EXPERIMENTAL == 0
      // Non-experimental path: derive all conversion layouts from the
      // freshly created primitives.
      fwd_bottom_data->create_internal_layout(poolingFwd, dnnResourceSrc);
      fwd_top_data->create_internal_layout(poolingFwd, dnnResourceDst);
      bwd_top_diff->create_internal_layout(poolingBwd, dnnResourceDiffDst);
      bwd_bottom_diff->create_internal_layout(poolingBwd, dnnResourceDiffSrc);
#endif
      dnnLayoutDelete<DType>(max_idx_datal);
      first_pass = false;
    }
    // Bind resources and run the forward primitive.
    pooling_res[dnnResourceSrc] = bottom_data;
    pooling_res[dnnResourceWorkspace] = max_idx_data;

    pooling_res[dnnResourceDst] = fwd_top_data->get_output_ptr(
      out.dptr_, fwd_top_data, out_data[pool_enum::kOut]);
    status = dnnExecute<DType>(poolingFwd, pooling_res);
    CHECK_EQ(status, E_SUCCESS);
#if MKL_EXPERIMENTAL == 0
    // Convert the result back to the user layout when the primitive wrote
    // into an internal-layout buffer.
    if (fwd_top_data->conversion_needed()) {
      fwd_top_data->convert_from_prv(out.dptr_);
    }
#endif
  }