Tensor CpuConvBackwardx()

in src/model/operation/convolution.cc [227:311]


Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x,
                        const ConvHandle &ch) {
  CHECK_EQ(dy.device()->lang(), kCpp);
  CHECK_EQ(W.device()->lang(), kCpp);
  CHECK_EQ(x.device()->lang(), kCpp);

  // dy (the gradient flowing back from the conv output) must keep the output
  // shape recorded in the handle: (N, num_filters, conv_height, conv_width).
  CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
        dy.shape(3) == ch.conv_width)
      << "input gradients shape should not change";

  CHECK(W.shape(0) == ch.num_filters && W.shape(1) == ch.channels &&
        W.shape(2) == ch.kernel_h && W.shape(3) == ch.kernel_w)
      << "weights shape should not change";

#ifdef USE_DNNL
  Tensor dx;
  dx.ResetLike(x);  // dx takes the shape, data type and device of x

  dy.device()->Exec(
      [dx, dy, x, &W, &ch](Context *ctx) mutable {
        using namespace dnnl;
        auto eng = ctx->dnnl_engine;
        auto s = ctx->dnnl_stream;
        using tag = memory::format_tag;
        auto dtype = dnnl::memory::data_type::f32;

        // Plain (non-blocked) memory descriptors built from the dims cached in
        // the ConvHandle.
        auto conv_src_md = memory::desc({ch.x_dims}, dtype, tag::nchw);
        auto conv_weights_md = memory::desc({ch.w_dims}, dtype, tag::goihw);
        auto conv_bias_md = memory::desc({ch.b_dims}, dtype, tag::x);
        auto conv_dst_md = memory::desc({ch.o_dims}, dtype, tag::nchw);

        // Wrap the SINGA blocks as DNNL memories: dx receives diff_src, dy
        // supplies diff_dst, and W supplies the weights.
        auto conv_user_src_memory =
            memory(conv_src_md, eng, dx.block()->mutable_data());
        auto conv_user_diff_dst_memory =
            memory(conv_dst_md, eng, dy.block()->mutable_data());
        auto conv_user_weights_memory =
            memory(conv_weights_md, eng, W.block()->mutable_data());

        // A forward primitive descriptor is needed only as a hint when
        // creating the backward-data primitive descriptor below.
        auto conv_desc = convolution_forward::desc(
            prop_kind::forward, algorithm::convolution_direct, conv_src_md,
            conv_weights_md, conv_bias_md, conv_dst_md, ch.s_dims, ch.p_dims,
            ch.p_dims);
        auto conv_pd = convolution_forward::primitive_desc(conv_desc, eng);

        auto conv_bwd_data_d = convolution_backward_data::desc(
            algorithm::convolution_direct, conv_src_md, conv_weights_md,
            conv_dst_md, ch.s_dims, ch.p_dims, ch.p_dims);
        auto conv_bwd_data_pd = convolution_backward_data::primitive_desc(
            conv_bwd_data_d, eng, conv_pd);

        // Run the backward-data primitive on the context's stream and block
        // until it finishes, since the lambda may be executed asynchronously.
        convolution_backward_data(conv_bwd_data_pd)
            .execute(s, {{DNNL_ARG_DIFF_DST, conv_user_diff_dst_memory},
                         {DNNL_ARG_WEIGHTS, conv_user_weights_memory},
                         {DNNL_ARG_DIFF_SRC, conv_user_src_memory}});
        s.wait();
      },
      // inputs (read blocks), outputs (write blocks) and the op name.
      {x.block(), dy.block(), W.block()}, {dx.block()}, "CpuConvBackwardx");

  return dx;

#else   // NOT USE_DNNL
/*  // disabled: importing Col2im here breaks the build
  Shape w_shape = W.shape();
  W.Reshape(Shape{ch.num_filters, ch.col_height});

  Tensor dx;
  dx.ResetLike(x);

  float *dx_b = new float[ch.imagesize];

  for (size_t num = 0; num < ch.batchsize; num++) {
    Tensor grad_b(Shape{ch.num_filters, ch.conv_height * ch.conv_width});
    CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
    Tensor dcol_b = Mult(Transpose(W), grad_b);
    auto dcol_data = dcol_b.data<float>();
    Col2im(dcol_data, ch.channels, ch.height, ch.width, ch.kernel_h,
           ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, dx_b);
    dx.CopyDataFromHostPtr(dx_b, ch.imagesize, num * ch.imagesize);
  }
  W.Reshape(w_shape);
  return dx;
*/
  // The im2col/col2im fallback above is disabled, so a build without DNNL has
  // no CPU backward-data path; fail explicitly instead of falling off the end
  // of a non-void function.
  LOG(FATAL) << "CpuConvBackwardx requires SINGA to be built with USE_DNNL";
#endif  // USE_DNNL
}
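
For context, a minimal usage sketch follows. It is not part of the file: the include paths and the ConvHandle constructor arguments (input, kernel size, stride, padding, in_channels, out_channels, bias) are assumptions to be checked against src/model/operation/convolution.h, and a DNNL-enabled CPU build is assumed.

// Hypothetical usage sketch; verify the ConvHandle constructor and include
// paths against the checked-out sources. Requires a build with USE_DNNL.
#include "singa/core/tensor.h"
#include "../src/model/operation/convolution.h"  // adjust to the local layout

void BackwardxExample() {
  using singa::Shape;
  using singa::Tensor;

  Tensor x(Shape{2, 3, 8, 8});  // NCHW input: batch 2, 3 channels, 8x8
  Tensor W(Shape{4, 3, 3, 3});  // 4 filters, 3 input channels, 3x3 kernels
  x.SetValue(1.0f);
  W.SetValue(0.5f);

  // Assumed constructor order: input, kernel, stride, padding, in_channels,
  // out_channels, bias.
  singa::ConvHandle ch(x, {3, 3}, {1, 1}, {1, 1}, 3, 4, false);

  // dy must match the output shape tracked by the handle.
  Tensor dy(Shape{2, 4, ch.conv_height, ch.conv_width});
  dy.SetValue(1.0f);

  Tensor dx = singa::CpuConvBackwardx(dy, W, x, ch);  // dx has the shape of x
}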