inline void CorrelationBackward()

in src/operator/correlation.cc [64:126]


inline void CorrelationBackward(const Tensor<cpu, 4, Dtype> &out_grad,
                                const Tensor<cpu, 4, Dtype> &in_grad1,
                                const Tensor<cpu, 4, Dtype> &in_grad2,
                                const Tensor<cpu, 4, Dtype> &tmp1,
                                const Tensor<cpu, 4, Dtype> &tmp2,
                                int top_channels_, int top_height_,
                                int top_width_, int pad_size_,
                                bool is_multiply, int max_displacement_,
                                int kernel_size_, int neighborhood_grid_radius_,
                                int neighborhood_grid_width_,
                                int  kernel_radius_, int stride1_,
                                int stride2_, int num,
                                int channels, int height, int width
                            ) {
  const float sumelems = kernel_size_ * kernel_size_ * channels;
  for (int i = 0 ; i < static_cast<index_t>(top_height_) ; i++)
     for (int j = 0 ; j < static_cast<index_t>(top_width_); j++)
        for (int nbatch = 0 ; nbatch < static_cast<index_t>(num) ; nbatch++) {
            int x1 = j*stride1_+max_displacement_;
            int y1 = i*stride1_+max_displacement_;
            for (int top_channel = 0 ; top_channel < top_channels_ ; top_channel++) {
              int s2o = (top_channel % neighborhood_grid_width_ - \
              neighborhood_grid_radius_) * stride2_;
              int s2p = (top_channel / neighborhood_grid_width_ - \
              neighborhood_grid_radius_) * stride2_;
              int x2 = x1 + s2o;
              int y2 = y1 + s2p;
              for (int h = 0; h < kernel_size_; h++)
                for (int w = 0; w < kernel_size_; w++)
                  for (int channel = 0 ; channel < channels; channel++) {
                    if (is_multiply == true) {
                      if ((y1 +  h - pad_size_ >= 0) && (x1 + w - pad_size_ >= 0) && \
                      (y1 + h < height +pad_size_) && (x1 + w < width + pad_size_)) {
                        in_grad1[nbatch][channel][y1+h-pad_size_][x1+w-pad_size_] += \
                        out_grad[nbatch][top_channel][i][j] * \
                        tmp2[nbatch][y2+h][x2+w][channel]/sumelems;
                       }
                       if ((y2 +  h - pad_size_ >= 0) && (x2 + w -pad_size_ >=0) && \
                       (y2 + h < height +pad_size_) && (x2 + w < width + pad_size_)) {
                       in_grad2[nbatch][channel][y2+h-pad_size_][x2+w-pad_size_] += \
                       out_grad[nbatch][top_channel][i][j] * \
                       tmp1[nbatch][y1+h][x1+w][channel]/sumelems;
                       }
                    } else {
                      if ((y1 +  h - pad_size_ >= 0) && (x1 + w -pad_size_ >=0) && \
                      (y1 + h < height + pad_size_) && (x1 + w < width + pad_size_)) {
                        Dtype sign  = (tmp1[nbatch][y1+h][x1+w][channel] >= \
                        tmp2[nbatch][y2+h][x2+w][channel])? Dtype(1.0) : Dtype(-1.0);
                        in_grad1[nbatch][channel][y1+h-pad_size_][x1+w-pad_size_] +=\
                        out_grad[nbatch][top_channel][i][j]*sign/sumelems;
                      }
                      if ((y2 +  h - pad_size_ >= 0) && (x2 + w - pad_size_ >=0) && \
                      (y2 + h < height + pad_size_) && (x2 + w < width + pad_size_)) {
                        Dtype sign  = (tmp1[nbatch][y1+h][x1+w][channel] >= \
                        tmp2[nbatch][y2+h][x2+w][channel])? Dtype(-1.0) : Dtype(1.0);
                        in_grad2[nbatch][channel][y2+h-pad_size_][x2+w-pad_size_] +=\
                        out_grad[nbatch][top_channel][i][j]*sign/sumelems;
                       }
                    }
                  }
               }
         }
}