in src/model/operation/convolution.cc [313:409]
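// Weight-gradient (and bias-gradient) computation of a 2D convolution on CPU.
// dy: gradient w.r.t. the convolution output; x: input; W: filters;
// ch: convolution handle holding shapes/strides/padding and the bias-gradient
// tensor ch.db. Returns the weight gradient dW.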
Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W,
const ConvHandle &ch) {
CHECK_EQ(dy.device()->lang(), kCpp);
CHECK_EQ(x.device()->lang(), kCpp);
CHECK_EQ(W.device()->lang(), kCpp);
CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
dy.shape(3) == ch.conv_width)
<< "input gradients shape should not change";
CHECK(x.shape(1) == ch.channels && x.shape(2) == ch.height &&
x.shape(3) == ch.width)
<< "input sample shape should not change";
#ifdef USE_DNNL
Tensor dW;
dW.ResetLike(W);
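// Submit the computation as a device task; the lambda below fills dW and
// ch.db in place using oneDNN (DNNL) primitives.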
dy.device()->Exec(
[dy, dW, x, &W, &ch](Context *ctx) mutable {
using namespace dnnl;
auto eng = ctx->dnnl_engine;
auto s = ctx->dnnl_stream;
using tag = memory::format_tag;
auto dtype = dnnl::memory::data_type::f32;
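// Memory descriptors for the tensors involved; the dims come from the
// ConvHandle: x_dims (input, nchw), w_dims (grouped weights, goihw),
// b_dims (bias), o_dims (output gradient dy, nchw).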
auto conv_src_md = memory::desc({ch.x_dims}, dtype, tag::nchw);
auto conv_weights_md = memory::desc({ch.w_dims}, dtype, tag::goihw);
auto conv_bias_md = memory::desc({ch.b_dims}, dtype, tag::x);
auto conv_dst_md = memory::desc({ch.o_dims}, dtype, tag::nchw);
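// Wrap the existing tensor buffers as oneDNN memory objects; no data is
// copied. x and dy are read, dW and ch.db are written by the primitive.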
auto conv_user_src_memory =
memory(conv_src_md, eng, x.block()->mutable_data());
auto conv_user_diff_weights_memory =
memory(conv_weights_md, eng, dW.block()->mutable_data());
auto conv_diff_bias_memory =
memory(conv_bias_md, eng, ch.db->block()->mutable_data());
auto conv_user_diff_dst_memory =
memory(conv_dst_md, eng, dy.block()->mutable_data());
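// Rebuild the forward primitive descriptor locally; it is only needed as the
// hint when creating the backward-weights primitive descriptor below
// (reusing the cached ch.conv_pd is noted below as very slow).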
auto conv_desc = convolution_forward::desc(
prop_kind::forward, algorithm::convolution_direct, conv_src_md,
conv_weights_md, conv_bias_md, conv_dst_md, ch.s_dims, ch.p_dims,
ch.p_dims);
auto conv_pd = convolution_forward::primitive_desc(conv_desc, eng);
// auto conv_pd = *ch.conv_pd; // very slow
auto conv_bwd_src_memory = conv_user_src_memory;
auto conv_diff_weights_memory = conv_user_diff_weights_memory;
auto conv_diff_dst_memory = conv_user_diff_dst_memory;
auto conv_bwd_weights_desc = convolution_backward_weights::desc(
algorithm::convolution_direct, conv_src_md, conv_weights_md,
conv_bias_md, conv_dst_md, ch.s_dims, ch.p_dims, ch.p_dims);
auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(
conv_bwd_weights_desc, eng, conv_pd);
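// Execute the backward-weights primitive: it consumes src (x) and
// diff_dst (dy) and writes diff_weights (dW) and diff_bias (ch.db) in one call.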
convolution_backward_weights(conv_bwd_weights_pd)
.execute(s,
{{DNNL_ARG_DIFF_DST, conv_diff_dst_memory},
{DNNL_ARG_SRC, conv_bwd_src_memory},
{DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory},
{DNNL_ARG_DIFF_BIAS, conv_diff_bias_memory}});
s.wait();
},
{x.block(), dy.block(), W.block()}, {dW.block(), ch.db->block()},
"CpuConvBackwardW");
return dW;
#else  // native cpp
/* // disabled: Im2col is not available in this translation unit
Tensor dW;
dW.ResetLike(W);
dW.SetValue(0.0f);
Shape w_shape = W.shape();
dW.Reshape(Shape{ch.num_filters, ch.col_height});
Tensor col_data(Shape{ch.col_height, ch.col_width});  // im2col buffer
float *data_col = new float[ch.col_height * ch.col_width];
auto in_data = x.data<float>();  // im2col reads the input image, not dy
for (size_t num = 0; num < ch.batchsize; num++) {
Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width,
ch.kernel_h, ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h,
ch.stride_w, data_col);
col_data.CopyDataFromHostPtr(data_col, ch.col_height * ch.col_width);
Tensor grad_b(Shape{ch.num_filters, ch.conv_height * ch.conv_width});
CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
dW += Mult(grad_b, Transpose(col_data));
}
delete[] data_col;
dW.Reshape(w_shape);
return dW;
*/
// No native CPU implementation is available; fail loudly instead of falling
// off the end of a non-void function when DNNL is disabled.
LOG(FATAL) << "CpuConvBackwardW requires DNNL; rebuild with USE_DNNL";
Tensor dW;
dW.ResetLike(W);
return dW;
#endif  // USE_DNNL
}