in src/operator/mkl/mkl_concat-inl.h [115:225]
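// Forward pass of the MKL-backed Concat operator: inputs are promoted to
// 4-D tensors, the dnnConcat primitive is created lazily on the first call,
// and the concatenation is then executed via dnnExecute.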
virtual void Forward(const OpContext &ctx,
                     const std::vector<TBlob> &in_data,
                     const std::vector<OpReqType> &req,
                     const std::vector<TBlob> &out_data,
                     const std::vector<TBlob> &aux_args) {
  using namespace mshadow;
  using namespace mshadow::expr;
  CHECK_EQ(static_cast<int>(in_data.size()), size_);
  CHECK_EQ(out_data.size(), 1U);
  CHECK_LT(dimension_, static_cast<size_t>(in_data[concat_enum::kData0].ndim()));
  Stream<xpu> *s = ctx.get_stream<xpu>();
  std::vector<Tensor<xpu, 4, DType> > data(size_);
  Tensor<xpu, 4, DType> out;
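  // MKL DNN primitives expect 4-D tensors, so 2-D and 3-D inputs (and the
  // output) are padded with trailing unit dimensions before their buffers
  // are fetched.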
  if (in_data[0].ndim() == 2) {
    for (int i = 0; i < size_; ++i) {
      Shape<4> dshape = Shape4(in_data[i].shape_[0],
                               in_data[i].shape_[1], 1, 1);
      data[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
        in_data[i], dshape, s);
    }
    Shape<4> dshape = Shape4(out_data[concat_enum::kOut].shape_[0],
                             out_data[concat_enum::kOut].shape_[1], 1, 1);
    out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
      out_data[concat_enum::kOut], dshape, s);
  } else if (in_data[0].ndim() == 3) {
    for (int i = 0; i < size_; ++i) {
      Shape<4> dshape = Shape4(in_data[i].shape_[0],
                               in_data[i].shape_[1], in_data[i].shape_[2], 1);
      data[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
        in_data[i], dshape, s);
    }
    Shape<4> dshape = Shape4(out_data[concat_enum::kOut].shape_[0],
                             out_data[concat_enum::kOut].shape_[1],
                             out_data[concat_enum::kOut].shape_[2], 1);
    out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
      out_data[concat_enum::kOut], dshape, s);
  } else {
    for (int i = 0; i < size_; ++i) {
      data[i] = mkl_experimental_direct_get<xpu, 4, DType>(in_data[i], s);
    }
    out = mkl_experimental_direct_get<xpu, 4, DType>(out_data[concat_enum::kOut], s);
  }
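  // Per-input split sizes along the concat dimension; filled in by
  // LayerSetUp on the first call and consumed by dnnSplitCreate below.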
  std::vector<size_t> split_channels(num_concats_);
  if (!init_mkldnn_) {
    init_mkldnn_ = true;
    LayerSetUp(data, out, 4, split_channels.data());
  }
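  // The MKL concat/split primitives are created lazily on the first forward
  // call; subsequent calls reuse them and only re-gather the input pointers.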
  dnnError_t e;
  std::vector<void*> bottom_data;
  bool isFirstPass = (concatFwd_ == NULL);
  std::vector<dnnLayout_t> layouts;
  if (isFirstPass) {
    layouts.resize(num_concats_);
  }
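  // Gather a raw pointer for every input, preferring the MKL-internal
  // (private) buffer when one exists and falling back to the user-layout
  // data otherwise. On the first pass, also record each input's layout so
  // the concat primitive can be created from them.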
  for (size_t i = 0; i < num_concats_; i++) {
    void *bottom_i = NULL;
#if MKL_EXPERIMENTAL == 1
    bottom_i = mkl_prv_data<DType>(in_data[i]);
    if (bottom_i != NULL) {
      if (isFirstPass) {
        std::shared_ptr<MKLData<DType> > mem_descr =
          mkl_get_mem_desc<DType>(in_data[i].Mkl_mem_);
        fwd_bottom_data_[i] = mem_descr;
        layouts[i] = mem_descr->layout_int;
      }
    }
#endif
    if (bottom_i == NULL) {
      bottom_i = data[i].dptr_;
      if (isFirstPass) {
        layouts[i] = fwd_bottom_data_[i]->layout_usr;
      }
    }
    bottom_data.push_back(bottom_i);
  }
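  // First pass only: build the forward concat primitive from the collected
  // input layouts, derive internal layouts for its destination, and create
  // the matching backward split primitive together with its per-input
  // layouts.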
  if (isFirstPass) {
    e = dnnConcatCreate<DType>(&concatFwd_, NULL, num_concats_, layouts.data());
    CHECK_EQ(e, E_SUCCESS);
    fwd_top_data_->create_internal_layout(concatFwd_, dnnResourceDst);
    bwd_top_diff_->create_internal_layout(concatFwd_, dnnResourceDst);
    e = dnnSplitCreate<DType>(&concatBwd_, NULL, num_concats_,
                              bwd_top_diff_->layout_int, split_channels.data());
    CHECK_EQ(e, E_SUCCESS);
    for (size_t n = 0; n < num_concats_; ++n) {
      fwd_bottom_data_[n]->create_internal_layout(concatFwd_,
        (dnnResourceType_t)(dnnResourceMultipleSrc + n));
      bwd_bottom_diff_[n]->create_internal_layout(concatBwd_,
        (dnnResourceType_t)(dnnResourceMultipleDst + n));
    }
  }
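  // Assemble the resource array for dnnExecute: one slot per source input
  // plus the destination. get_output_ptr yields the MKL-internal output
  // buffer when one is in use, and the plain out.dptr_ otherwise.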
  void *concat_res[dnnResourceNumber];
  for (size_t i = 0; i < num_concats_; ++i) {
    concat_res[dnnResourceMultipleSrc + i] = bottom_data[i];
  }
  concat_res[dnnResourceDst] = fwd_top_data_->get_output_ptr(out.dptr_,
    fwd_top_data_, out_data[concat_enum::kOut]);
  e = dnnExecute<DType>(concatFwd_, concat_res);
  CHECK_EQ(e, E_SUCCESS);
}