in src/model/operation/convolution.cc [227:311]
Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x,
const ConvHandle &ch) {
CHECK_EQ(dy.device()->lang(), kCpp);
CHECK_EQ(W.device()->lang(), kCpp);
CHECK_EQ(x.device()->lang(), kCpp);
CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
dy.shape(3) == ch.conv_width)
<< "input gradients shape should not change";
CHECK(W.shape(0) == ch.num_filters && W.shape(1) == ch.channels &&
W.shape(2) == ch.kernel_h && W.shape(3) == ch.kernel_w)
<< "weights shape should not change";
#ifdef USE_DNNL
Tensor dx;
dx.ResetLike(x);
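  // Enqueue the work on dy's device; the trailing block lists declare the
  // read set {x, dy, W} and the write set {dx} for the async scheduler.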
dy.device()->Exec(
[dx, dy, x, &W, &ch](Context *ctx) mutable {
using namespace dnnl;
auto eng = ctx->dnnl_engine;
auto s = ctx->dnnl_stream;
using tag = memory::format_tag;
auto dtype = dnnl::memory::data_type::f32;
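        // Logical tensor descriptors: NCHW activations, OIHW weights, and a
        // 1-D bias (the bias is only needed for the forward hint below).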
auto conv_src_md = memory::desc({ch.x_dims}, dtype, tag::nchw);
        auto conv_weights_md = memory::desc({ch.w_dims}, dtype, tag::oihw);
auto conv_bias_md = memory::desc({ch.b_dims}, dtype, tag::x);
auto conv_dst_md = memory::desc({ch.o_dims}, dtype, tag::nchw);
auto conv_user_src_memory =
memory(conv_src_md, eng, dx.block()->mutable_data());
auto conv_user_diff_dst_memory =
memory(conv_dst_md, eng, dy.block()->mutable_data());
auto conv_user_weights_memory =
memory(conv_weights_md, eng, W.block()->mutable_data());
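        // oneDNN requires a forward primitive descriptor as a hint when
        // constructing the backward-data primitive descriptor below.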
auto conv_desc = convolution_forward::desc(
prop_kind::forward, algorithm::convolution_direct, conv_src_md,
conv_weights_md, conv_bias_md, conv_dst_md, ch.s_dims, ch.p_dims,
ch.p_dims);
auto conv_pd = convolution_forward::primitive_desc(conv_desc, eng);
auto conv_bwd_data_d = convolution_backward_data::desc(
algorithm::convolution_direct, conv_src_md, conv_weights_md,
conv_dst_md, ch.s_dims, ch.p_dims, ch.p_dims);
auto conv_bwd_data_pd = convolution_backward_data::primitive_desc(
conv_bwd_data_d, eng, conv_pd);
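        // Execute: dy (DIFF_DST) and W (WEIGHTS) are read; dx (DIFF_SRC)
        // receives the input gradient in place.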
        convolution_backward_data(conv_bwd_data_pd)
            .execute(s, {{DNNL_ARG_DIFF_DST, conv_user_diff_dst_memory},
                         {DNNL_ARG_WEIGHTS, conv_user_weights_memory},
                         {DNNL_ARG_DIFF_SRC, conv_user_src_memory}});
        s.wait();
},
{x.block(), dy.block(), W.block()}, {dx.block()}, "CpuConvBackwardx");
return dx;
#else  // NOT USE_DNNL
  // Fallback GEMM + Col2im path. NOTE: this assumes a free function Col2im
  // with the signature used below is declared in a header visible here.
  Shape w_shape = W.shape();
  W.Reshape(Shape{ch.num_filters, ch.col_height});

  Tensor dx;
  dx.ResetLike(x);

  float *dx_b = new float[ch.imagesize];
  for (size_t num = 0; num < ch.batchsize; num++) {
    // Per image: recover the column-matrix gradient dcol = W^T * dy_b,
    // then scatter the columns back into image (NCHW) layout.
    Tensor grad_b(Shape{ch.num_filters, ch.conv_height * ch.conv_width});
    CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
    Tensor dcol_b = Mult(Transpose(W), grad_b);
    auto dcol_data = dcol_b.data<float>();
    Col2im(dcol_data, ch.channels, ch.height, ch.width, ch.kernel_h,
           ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, dx_b);
    dx.CopyDataFromHostPtr(dx_b, ch.imagesize, num * ch.imagesize);
  }
  delete[] dx_b;  // release the per-image staging buffer
  W.Reshape(w_shape);
  return dx;
#endif // USE_DNNL
}