virtual void Forward()

in src/operator/mkl/mkl_concat-inl.h [115:225]


  /*!
   * \brief Concatenate the input blobs along dimension_ with the MKL DNN
   *        concat primitive.
   *
   * 2-D and 3-D blobs are viewed as 4-D tensors by appending trailing unit
   * dimensions; blobs of any other rank are fetched directly as 4-D tensors.
   * On the first call LayerSetUp() is run; while concatFwd_ is still NULL the
   * concat (forward) and split (backward) primitives are created together with
   * the internal layouts of the cached MKLData/MKLDiff objects. Later calls
   * reuse those primitives.
   *
   * \param ctx      operator context; supplies the stream.
   * \param in_data  the size_ input blobs to concatenate.
   * \param req      write requests (not read by this method).
   * \param out_data exactly one output blob, at concat_enum::kOut.
   * \param aux_args auxiliary states (not read by this method).
   */
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(static_cast<int>(in_data.size()), size_);
    CHECK_EQ(out_data.size(), 1);
    CHECK_LT(dimension_,
             static_cast<size_t>(in_data[concat_enum::kData0].ndim()));
    Stream<xpu> *s = ctx.get_stream<xpu>();
    std::vector<Tensor<xpu, 4, DType> > data(size_);
    Tensor<xpu, 4, DType> out;
    if (in_data[0].ndim() == 2) {
      // (N, C) -> (N, C, 1, 1)
      for (int i = 0; i < size_; ++i) {
        Shape<4> dshape = Shape4(in_data[i].shape_[0],
                                 in_data[i].shape_[1], 1, 1);
        data[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
          in_data[i], dshape, s);
      }
      Shape<4> dshape = Shape4(out_data[concat_enum::kOut].shape_[0],
                               out_data[concat_enum::kOut].shape_[1], 1, 1);
      out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
        out_data[concat_enum::kOut], dshape, s);
    } else if (in_data[0].ndim() == 3) {
      // (N, C, W) -> (N, C, W, 1)
      for (int i = 0; i < size_; ++i) {
        Shape<4> dshape = Shape4(in_data[i].shape_[0],
          in_data[i].shape_[1], in_data[i].shape_[2], 1);
        data[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
          in_data[i], dshape, s);
      }
      Shape<4> dshape = Shape4(out_data[concat_enum::kOut].shape_[0],
        out_data[concat_enum::kOut].shape_[1],
        out_data[concat_enum::kOut].shape_[2], 1);
      out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
        out_data[concat_enum::kOut], dshape, s);
    } else {
      for (int i = 0; i < size_; ++i) {
        data[i] = mkl_experimental_direct_get<xpu, 4, DType>(in_data[i], s);
      }
      out = mkl_experimental_direct_get<xpu, 4, DType>(out_data[concat_enum::kOut], s);
    }

    // Per-input split sizes: written by LayerSetUp, read by dnnSplitCreate.
    // Held in a vector so it is freed even if a CHECK_* below throws, and
    // value-initialized so dnnSplitCreate never reads indeterminate memory.
    std::vector<size_t> split_channels(num_concats_);
    if (!init_mkldnn_) {
      init_mkldnn_ = true;
      LayerSetUp(data, out, 4, split_channels.data());
    }

    const bool isFirstPass = (concatFwd_ == NULL);
    // Source layouts are only needed to build concatFwd_ on the first pass.
    std::vector<dnnLayout_t> layouts(isFirstPass ? num_concats_ : 0);
    std::vector<void*> bottom_data;
    bottom_data.reserve(num_concats_);

    for (size_t i = 0; i < num_concats_; ++i) {
      void *bottom_i = NULL;
#if MKL_EXPERIMENTAL == 1
      // Prefer the blob's private MKL buffer when it has one.
      bottom_i = mkl_prv_data<DType>(in_data[i]);
      if (bottom_i != NULL && isFirstPass) {
        std::shared_ptr<MKLData<DType> > mem_descr =
          mkl_get_mem_desc<DType>(in_data[i].Mkl_mem_);
        fwd_bottom_data_[i] = mem_descr;
        layouts[i] = mem_descr->layout_int;
      }
#endif
      if (bottom_i == NULL) {
        // No private buffer: use the plain tensor memory in user layout.
        bottom_i = data[i].dptr_;
        if (isFirstPass) {
          layouts[i] = fwd_bottom_data_[i]->layout_usr;
        }
      }
      bottom_data.push_back(bottom_i);
    }

    if (isFirstPass) {
      dnnError_t e = dnnConcatCreate<DType>(&concatFwd_, NULL, num_concats_,
                                            layouts.data());
      CHECK_EQ(e, E_SUCCESS);

      fwd_top_data_->create_internal_layout(concatFwd_, dnnResourceDst);
      bwd_top_diff_->create_internal_layout(concatFwd_, dnnResourceDst);

      // The backward split primitive is built here too, from the concat
      // output layout and the per-input split sizes.
      e = dnnSplitCreate<DType>(&concatBwd_, NULL, num_concats_,
            bwd_top_diff_->layout_int, split_channels.data());
      CHECK_EQ(e, E_SUCCESS);

      for (size_t n = 0; n < num_concats_; ++n) {
        fwd_bottom_data_[n]->create_internal_layout(concatFwd_,
          (dnnResourceType_t)(dnnResourceMultipleSrc + n));
        bwd_bottom_diff_[n]->create_internal_layout(concatBwd_,
          (dnnResourceType_t)(dnnResourceMultipleDst + n));
      }
    }

    // Slots not assigned below stay null rather than indeterminate.
    void *concat_res[dnnResourceNumber] = {};
    for (size_t i = 0; i < num_concats_; ++i) {
      concat_res[dnnResourceMultipleSrc + i] = bottom_data[i];
    }

    concat_res[dnnResourceDst] = fwd_top_data_->get_output_ptr(out.dptr_,
      fwd_top_data_, out_data[concat_enum::kOut]);
    const dnnError_t exec_status = dnnExecute<DType>(concatFwd_, concat_res);
    CHECK_EQ(exec_status, E_SUCCESS);
  }