in src/operator/correlation.cc [64:126]
inline void CorrelationBackward(const Tensor<cpu, 4, Dtype> &out_grad,
const Tensor<cpu, 4, Dtype> &in_grad1,
const Tensor<cpu, 4, Dtype> &in_grad2,
const Tensor<cpu, 4, Dtype> &tmp1,
const Tensor<cpu, 4, Dtype> &tmp2,
int top_channels_, int top_height_,
int top_width_, int pad_size_,
bool is_multiply, int max_displacement_,
int kernel_size_, int neighborhood_grid_radius_,
int neighborhood_grid_width_,
int kernel_radius_, int stride1_,
int stride2_, int num,
int channels, int height, int width
) {
const float sumelems = kernel_size_ * kernel_size_ * channels;
for (int i = 0 ; i < static_cast<index_t>(top_height_) ; i++)
for (int j = 0 ; j < static_cast<index_t>(top_width_); j++)
for (int nbatch = 0 ; nbatch < static_cast<index_t>(num) ; nbatch++) {
int x1 = j*stride1_+max_displacement_;
int y1 = i*stride1_+max_displacement_;
for (int top_channel = 0 ; top_channel < top_channels_ ; top_channel++) {
int s2o = (top_channel % neighborhood_grid_width_ - \
neighborhood_grid_radius_) * stride2_;
int s2p = (top_channel / neighborhood_grid_width_ - \
neighborhood_grid_radius_) * stride2_;
int x2 = x1 + s2o;
int y2 = y1 + s2p;
for (int h = 0; h < kernel_size_; h++)
for (int w = 0; w < kernel_size_; w++)
for (int channel = 0 ; channel < channels; channel++) {
if (is_multiply == true) {
if ((y1 + h - pad_size_ >= 0) && (x1 + w - pad_size_ >= 0) && \
(y1 + h < height +pad_size_) && (x1 + w < width + pad_size_)) {
in_grad1[nbatch][channel][y1+h-pad_size_][x1+w-pad_size_] += \
out_grad[nbatch][top_channel][i][j] * \
tmp2[nbatch][y2+h][x2+w][channel]/sumelems;
}
if ((y2 + h - pad_size_ >= 0) && (x2 + w -pad_size_ >=0) && \
(y2 + h < height +pad_size_) && (x2 + w < width + pad_size_)) {
in_grad2[nbatch][channel][y2+h-pad_size_][x2+w-pad_size_] += \
out_grad[nbatch][top_channel][i][j] * \
tmp1[nbatch][y1+h][x1+w][channel]/sumelems;
}
} else {
if ((y1 + h - pad_size_ >= 0) && (x1 + w -pad_size_ >=0) && \
(y1 + h < height + pad_size_) && (x1 + w < width + pad_size_)) {
Dtype sign = (tmp1[nbatch][y1+h][x1+w][channel] >= \
tmp2[nbatch][y2+h][x2+w][channel])? Dtype(1.0) : Dtype(-1.0);
in_grad1[nbatch][channel][y1+h-pad_size_][x1+w-pad_size_] +=\
out_grad[nbatch][top_channel][i][j]*sign/sumelems;
}
if ((y2 + h - pad_size_ >= 0) && (x2 + w - pad_size_ >=0) && \
(y2 + h < height + pad_size_) && (x2 + w < width + pad_size_)) {
Dtype sign = (tmp1[nbatch][y1+h][x1+w][channel] >= \
tmp2[nbatch][y2+h][x2+w][channel])? Dtype(-1.0) : Dtype(1.0);
in_grad2[nbatch][channel][y2+h-pad_size_][x2+w-pad_size_] +=\
out_grad[nbatch][top_channel][i][j]*sign/sumelems;
}
}
}
}
}
}