in src/model/layer/convolution.cc [27:104]
void Convolution::Setup(const Shape &in_sample, const LayerConf &conf) {
Layer::Setup(in_sample, conf);
ConvolutionConf conv_conf = conf.convolution_conf();
// kernel_size, pad, and stride are repeated fields.
if (conv_conf.kernel_size_size() > 0) {
if (conv_conf.kernel_size_size() == 1) {
kernel_w_ = kernel_h_ = conv_conf.kernel_size(0);
} else {
kernel_w_ = conv_conf.kernel_size(0);
kernel_h_ = conv_conf.kernel_size(1);
}
} else {
kernel_w_ = conv_conf.kernel_w();
kernel_h_ = conv_conf.kernel_h();
}
CHECK_GT(kernel_w_, 0u);
CHECK_GT(kernel_h_, 0u);
if (conv_conf.pad_size() > 0) {
if (conv_conf.pad_size() == 1) {
pad_w_ = pad_h_ = conv_conf.pad(0);
} else {
pad_w_ = conv_conf.pad(0);
pad_h_ = conv_conf.pad(1);
}
} else {
pad_w_ = conv_conf.pad_w();
pad_h_ = conv_conf.pad_h();
}
CHECK_GE(pad_w_, 0u);
CHECK_GE(pad_h_, 0u);
const int kStrideDefault = 1;
if (conv_conf.stride_size() > 0) {
if (conv_conf.stride_size() == 1) {
stride_w_ = stride_h_ = conv_conf.stride(0);
} else {
stride_w_ = conv_conf.stride(0);
stride_h_ = conv_conf.stride(1);
}
} else {
stride_w_ = kStrideDefault;
stride_h_ = kStrideDefault;
if (conv_conf.has_stride_w()) {
stride_w_ = conv_conf.stride_w();
}
if (conv_conf.has_stride_h()) {
stride_h_ = conv_conf.stride_h();
}
}
CHECK_GT(stride_w_, 0u);
CHECK_GE(stride_h_, 0u); // 0 for 1D conv
num_filters_ = conv_conf.num_output();
bias_term_ = conv_conf.bias_term();
// Shape of input image
CHECK_EQ(in_sample.size(), 3u);
channels_ = in_sample.at(0);
height_ = in_sample.at(1);
width_ = in_sample.at(2);
conv_height_ = 1;
if (stride_h_ > 0)
conv_height_ = (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1;
conv_width_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
out_sample_shape_ = vector<size_t>{num_filters_, conv_height_, conv_width_};
col_height_ = channels_ * kernel_w_ * kernel_h_;
col_width_ = conv_height_ * conv_width_;
// Setup shape of weight_ and bias_
weight_.Resize(Shape{num_filters_, col_height_});
if (bias_term_)
bias_.Resize(Shape{num_filters_});
// Assume the order of param is: weight, bias
for (const auto &spec : conf.param()) param_specs_.push_back(spec);
}