Tensor CpuConvBackwardW()

in src/model/operation/convolution.cc [313:409]
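
Computes the weight gradient of a 2D convolution on CPU. Given the output
gradient dy, the input batch x, and the weight tensor W (used only for its
shape and device), it returns a tensor dW shaped like W. With USE_DNNL the
work is delegated to a DNNL backward-weights primitive, which also writes the
bias gradient into ch.db; without DNNL only a disabled im2col fallback
remains.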


Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W,
                        const ConvHandle &ch) {
  CHECK_EQ(dy.device()->lang(), kCpp);
  CHECK_EQ(x.device()->lang(), kCpp);
  CHECK_EQ(W.device()->lang(), kCpp);

  CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
        dy.shape(3) == ch.conv_width)
      << "output gradient shape must match the conv output shape";

  CHECK(x.shape(1) == ch.channels && x.shape(2) == ch.height &&
        x.shape(3) == ch.width)
      << "input sample shape must match the conv handle";

#ifdef USE_DNNL
  Tensor dW;
  dW.ResetLike(W);

  // Build and run the DNNL backward-weights primitive asynchronously on the
  // tensor's device; dW and the bias gradient ch.db are written in place.
  dy.device()->Exec(
      [dy, dW, x, &W, &ch](Context *ctx) mutable {
        using namespace dnnl;
        auto eng = ctx->dnnl_engine;
        auto s = ctx->dnnl_stream;
        using tag = memory::format_tag;
        auto dtype = dnnl::memory::data_type::f32;

        // Plain-layout memory descriptors matching the user tensors
        // (NCHW data, grouped OIHW weights).
        auto conv_src_md = memory::desc({ch.x_dims}, dtype, tag::nchw);
        auto conv_weights_md = memory::desc({ch.w_dims}, dtype, tag::goihw);
        auto conv_bias_md = memory::desc({ch.b_dims}, dtype, tag::x);
        auto conv_dst_md = memory::desc({ch.o_dims}, dtype, tag::nchw);

        // Wrap the existing tensor buffers in DNNL memory objects (no copy).
        auto conv_user_src_memory =
            memory(conv_src_md, eng, x.block()->mutable_data());
        auto conv_user_diff_weights_memory =
            memory(conv_weights_md, eng, dW.block()->mutable_data());
        auto conv_diff_bias_memory =
            memory(conv_bias_md, eng, ch.db->block()->mutable_data());
        auto conv_user_diff_dst_memory =
            memory(conv_dst_md, eng, dy.block()->mutable_data());

        // The backward-weights primitive_desc below requires a forward
        // primitive_desc as a hint, so one is rebuilt here.
        auto conv_desc = convolution_forward::desc(
            prop_kind::forward, algorithm::convolution_direct, conv_src_md,
            conv_weights_md, conv_bias_md, conv_dst_md, ch.s_dims, ch.p_dims,
            ch.p_dims);
        auto conv_pd = convolution_forward::primitive_desc(conv_desc, eng);

        // auto conv_pd = *ch.conv_pd;  // reusing the cached copy is very slow

        // The user memory already matches the plain descriptors, so it is
        // passed to the primitive directly without reorders.
        auto conv_bwd_src_memory = conv_user_src_memory;
        auto conv_diff_weights_memory = conv_user_diff_weights_memory;
        auto conv_diff_dst_memory = conv_user_diff_dst_memory;

        auto conv_bwd_weights_desc = convolution_backward_weights::desc(
            algorithm::convolution_direct, conv_src_md, conv_weights_md,
            conv_bias_md, conv_dst_md, ch.s_dims, ch.p_dims, ch.p_dims);
        auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(
            conv_bwd_weights_desc, eng, conv_pd);

        convolution_backward_weights(conv_bwd_weights_pd)
            .execute(s, {{DNNL_ARG_DIFF_DST, conv_diff_dst_memory},
                         {DNNL_ARG_SRC, conv_bwd_src_memory},
                         {DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory},
                         {DNNL_ARG_DIFF_BIAS, conv_diff_bias_memory}});
        s.wait();
      },
      {x.block(), dy.block(), W.block()}, {dW.block(), ch.db->block()},
      "CpuConvBackwardW");

  return dW;
#else   // native cpp
/* Native im2col fallback, disabled because Im2col cannot be imported here.
   Per sample it unrolls the input image into columns and accumulates
   dW += grad_b * col(x)^T.
  Tensor dW;
  dW.ResetLike(W);
  dW.SetValue(0.0f);

  Shape w_shape = W.shape();
  dW.Reshape(Shape{ch.num_filters, ch.col_height});

  Tensor col_data(Shape{ch.col_height, ch.col_width});  // unrolled image

  float *data_col = new float[ch.col_height * ch.col_width];
  auto in_data = x.data<float>();  // unroll the input x, not the gradient dy
  for (size_t num = 0; num < ch.batchsize; num++) {
    Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width,
           ch.kernel_h, ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h,
           ch.stride_w, data_col);
    col_data.CopyDataFromHostPtr(data_col, ch.col_height * ch.col_width);
    Tensor grad_b(Shape{ch.num_filters, ch.conv_height * ch.conv_width});
    CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
    dW += Mult(grad_b, Transpose(col_data));
  }
  delete[] data_col;  // was leaked in the original
  dW.Reshape(w_shape);
  return dW;
*/
  // The fallback is disabled, so fail loudly instead of falling off the end
  // of a value-returning function.
  CHECK(false) << "CpuConvBackwardW requires DNNL; rebuild with USE_DNNL";
  return Tensor();  // unreachable, silences the missing-return warning
#endif  // USE_DNNL
}
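
For orientation, a minimal calling sketch. The ConvHandle constructor
arguments below are assumptions inferred from the fields this function reads
(kernel, stride, and padding sizes, channel and filter counts); check the
actual header for the exact signature.

  using namespace singa;

  Tensor x(Shape{8, 3, 32, 32});    // NCHW input batch on the default (CPU) device
  Tensor W(Shape{16, 3, 3, 3});     // 16 filters over 3 channels, 3x3 kernels
  Tensor dy(Shape{8, 16, 32, 32});  // gradient w.r.t. the conv output (stride 1, pad 1)
  x.SetValue(1.0f);
  dy.SetValue(0.1f);

  ConvHandle ch(x, {3, 3}, {1, 1}, {1, 1},  // kernel, stride, padding (assumed order)
                3, 16, true);               // in channels, num filters, use bias

  Tensor dW = CpuConvBackwardW(dy, x, W, ch);  // returns a tensor shaped like W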