in src/operator/mkl/mkl_pooling-inl.h [153:288]
virtual void Forward(const OpContext &ctx,
                     const std::vector<TBlob> &in_data,
                     const std::vector<OpReqType> &req,
                     const std::vector<TBlob> &out_data,
                     const std::vector<TBlob> &aux_args) {
  using namespace mshadow;
  using namespace mshadow::expr;
  CHECK_EQ(in_data.size(), 1);
  CHECK_EQ(out_data.size(), 1);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  if (param_.kernel.ndim() >= 3) {
    LOG(FATAL) << "3D pooling (kernel.ndim() >= 3) is not implemented";
  }
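  // View the input and output blobs as 4D (NCHW) tensors; only 2D pooling
  // is handled by this path.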
  Tensor<xpu, 4, DType> data = mkl_experimental_direct_get<xpu, 4, DType>(
      in_data[pool_enum::kData], s);
  Tensor<xpu, 4, DType> out = mkl_experimental_direct_get<xpu, 4, DType>(
      out_data[pool_enum::kOut], s);
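  // One-time setup: LayerSetUp fills in the pooling geometry (kernel_size,
  // kernel_stride, src_offset) and the layout descriptors used below.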
  if (!init_mkldnn_) {
    LayerSetUp(data, out);
    init_mkldnn_ = true;
  }
  bool first_pass = false;
  if (poolingFwd == NULL) first_pass = true;
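  // Map the MXNet pooling type onto the corresponding MKL DNN algorithm.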
  dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
  switch (param_.pool_type) {
  case pool_enum::kMaxPooling:
    algorithm = dnnAlgorithmPoolingMax;
    break;
  case pool_enum::kAvgPooling:
    algorithm = (param_.pooling_convention == pool_enum::kValid) ?
        dnnAlgorithmPoolingAvgIncludePadding : dnnAlgorithmPoolingAvg;
    break;
  default:
    LOG(FATAL) << "Unknown pooling method.";
  }
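  // dnnExecute consumes an array of resource pointers indexed by
  // dnnResourceType_t (src, dst, workspace, ...).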
  dnnError_t status;
  void* pooling_res[dnnResourceNumber];
  void* bottom_data = NULL;
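  // With the experimental MKL memory support, reuse the input's private
  // (MKL-internal layout) buffer if the TBlob carries one; mkl_prv_data
  // returns NULL when it does not.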
#if MKL_EXPERIMENTAL == 1
  bottom_data =
      reinterpret_cast<void *>(mkl_prv_data<DType>(in_data[pool_enum::kData]));
#endif
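  // Choose the border handling that matches the requested pooling convention.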
  dnnBorder_t border_type = dnnBorderZerosAsymm;
  switch (param_.pooling_convention) {
  case pool_enum::kFull:
    border_type = dnnBorderZeros;
    break;
  case pool_enum::kValid:
    border_type = dnnBorderZerosAsymm;
    break;
  default:
    border_type = dnnBorderZerosAsymm;
    break;
  }
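  // No private-layout input is available: feed the plain user-layout buffer
  // and, on the first pass, create the primitives from the user layout.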
  if (NULL == bottom_data) {
    bottom_data = data.dptr_;
    if (NULL == poolingFwd) {
      status = dnnPoolingCreateForward<DType>(&poolingFwd, NULL,
                                              algorithm, fwd_bottom_data->layout_usr,
                                              kernel_size, kernel_stride,
                                              src_offset, border_type);
      CHECK_EQ(status, E_SUCCESS);
      // Now create poolingBwd
      status = dnnPoolingCreateBackward<DType>(&poolingBwd, NULL,
                                               algorithm, fwd_bottom_data->layout_usr,
                                               kernel_size, kernel_stride,
                                               src_offset, border_type);
      CHECK_EQ(status, E_SUCCESS);
    }
  }
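  // Experimental path: the input already lives in an MKL private layout, so
  // adopt its memory descriptor and create the primitives on that internal
  // layout instead of the user layout.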
#if MKL_EXPERIMENTAL == 1
  if (NULL != bottom_data) {
    if (NULL == poolingFwd) {
      std::shared_ptr<MKLMemHolder> bottom_data_mem = in_data[pool_enum::kData].Mkl_mem_;
      std::shared_ptr<PrvMemDescr> bottom_prv_descriptor =
          bottom_data_mem->get_prv_descriptor();
      CHECK_EQ(bottom_prv_descriptor->get_descr_type(),
               PrvMemDescr::PRV_DESCR_MKL2017);
      std::shared_ptr<MKLData<DType> > mem_descr =
          std::static_pointer_cast<MKLData<DType>>(bottom_prv_descriptor);
      CHECK(mem_descr != nullptr);
      fwd_bottom_data = mem_descr;
      status = dnnPoolingCreateForward<DType>(&poolingFwd, NULL,
                                              algorithm, fwd_bottom_data->layout_int,
                                              kernel_size, kernel_stride,
                                              src_offset, border_type);
      CHECK_EQ(status, E_SUCCESS);
      fwd_top_data->create_internal_layout(poolingFwd, dnnResourceDst);
      // Now create poolingBwd
      status = dnnPoolingCreateBackward<DType>(&poolingBwd, NULL,
                                               algorithm, fwd_bottom_data->layout_int,
                                               kernel_size, kernel_stride,
                                               src_offset, border_type);
      CHECK_EQ(status, E_SUCCESS);
      bwd_top_diff->create_internal_layout(poolingFwd, dnnResourceDst);
      bwd_bottom_diff->create_internal_layout(poolingFwd, dnnResourceSrc);
    }
  }
#endif
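  // First pass only: allocate the pooling workspace buffer (max_idx_data, the
  // pooling indices needed by the backward pass) and, in the non-experimental
  // build, derive the internal layouts for the data and gradient conversions.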
  if (first_pass) {
    dnnLayout_t max_idx_datal = NULL;
    status = dnnLayoutCreateFromPrimitive<DType>(
        &max_idx_datal, poolingFwd, dnnResourceWorkspace);
    CHECK_EQ(status, E_SUCCESS);
    status = dnnAllocateBuffer<DType>(reinterpret_cast<void**>(&max_idx_data), max_idx_datal);
    CHECK_EQ(status, E_SUCCESS);
#if MKL_EXPERIMENTAL == 0
    fwd_bottom_data->create_internal_layout(poolingFwd, dnnResourceSrc);
    fwd_top_data->create_internal_layout(poolingFwd, dnnResourceDst);
    bwd_top_diff->create_internal_layout(poolingBwd, dnnResourceDiffDst);
    bwd_bottom_diff->create_internal_layout(poolingBwd, dnnResourceDiffSrc);
#endif
    dnnLayoutDelete<DType>(max_idx_datal);
    first_pass = false;
  }
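  // Bind src/workspace/dst resources and run the forward pooling primitive.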
  pooling_res[dnnResourceSrc] = bottom_data;
  pooling_res[dnnResourceWorkspace] = max_idx_data;
  pooling_res[dnnResourceDst] = fwd_top_data->get_output_ptr(
      out.dptr_, fwd_top_data, out_data[pool_enum::kOut]);
  status = dnnExecute<DType>(poolingFwd, pooling_res);
  CHECK_EQ(status, E_SUCCESS);
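  // Without the experimental memory manager the primitive may have written its
  // output in an internal layout; convert it back into the user buffer.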
#if MKL_EXPERIMENTAL == 0
  if (fwd_top_data->conversion_needed()) {
    fwd_top_data->convert_from_prv(out.dptr_);
  }
#endif
}