Status operator()

in tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.cc [41:118]


  // Computes the FlowNet correlation cost volume on the CPU.
  //
  // How each output element is computed:
  //   - For every output position (n, h, w), loop over every displacement
  //     (ti, tj) in a (2 * displacement_rad + 1)^2 grid.
  //   - Sum the products of input_a and input_b over a
  //     (2 * kernel_rad + 1)^2 patch and all iC channels.
  //   - Divide the sum by K = kernel_size^2 * iC. Taps that fall outside
  //     either input are skipped (they contribute zero) but still count
  //     toward K.
  //
  // Params:
  //   context          - op context; its CPU worker thread pool runs the work.
  //   input_a_t        - first input, laid out according to `data_format`.
  //   input_b_t        - second input, indexed with the same layout.
  //   output_t         - pre-allocated output, always indexed as NCHW; its C
  //                      dimension enumerates displacements (see `tc`).
  //   kernel_size      - patch size; the patch radius is (kernel_size - 1) / 2.
  //   max_displacement - together with stride_2, sets the displacement grid
  //                      radius (max_displacement / stride_2).
  //   stride_1         - step between output positions in input coordinates.
  //   stride_2         - step between neighbouring displacements.
  //   pad              - offset subtracted from output coordinates.
  //   data_format      - FORMAT_NCHW; any other value is indexed as NHWC.
  // Returns Status::OK(); no argument validation is performed here.
  Status operator()(OpKernelContext* context, const Tensor& input_a_t,
                    const Tensor& input_b_t, Tensor* output_t,
                    /* params */
                    int kernel_size, int max_displacement, int stride_1,
                    int stride_2, int pad, TensorFormat data_format) {
    const int32 oN = GetTensorDim(*output_t, FORMAT_NCHW, 'N');
    // const int32 oC = GetTensorDim(*output_t, FORMAT_NCHW, 'C');
    const int32 oH = GetTensorDim(*output_t, FORMAT_NCHW, 'H');
    const int32 oW = GetTensorDim(*output_t, FORMAT_NCHW, 'W');
    const int32 iH = GetTensorDim(input_a_t, data_format, 'H');
    const int32 iW = GetTensorDim(input_a_t, data_format, 'W');
    const int32 iC = GetTensorDim(input_a_t, data_format, 'C');

    // Normalization factor: number of terms in each correlation sum.
    const int K = kernel_size * kernel_size * iC;

    const auto input_a = input_a_t.tensor<Dtype, 4>();
    const auto input_b = input_b_t.tensor<Dtype, 4>();
    auto output = output_t->tensor<Dtype, 4>();
    // Zero up front so any output element not written below stays 0.
    output.setZero();

    const int kernel_rad = (kernel_size - 1) / 2;
    const int displacement_rad = max_displacement / stride_2;
    const int displacement_size = 2 * displacement_rad + 1;

    const bool is_NCHW = (data_format == FORMAT_NCHW);

    // Estimated operations per output position (all displacements, all
    // channels). The product is evaluated in int64 so that large channel
    // counts or displacement grids cannot overflow an int intermediate.
    const int64 cost_per_pixel =
        static_cast<int64>(iC) * displacement_size * displacement_size *
        (2 * kernel_rad + 1) * (2 * kernel_rad + 1) *
        (Eigen::TensorOpCost::MulCost<Dtype>() +
         Eigen::TensorOpCost::AddCost<Dtype>());

    // Number of output positions per batch item, kept in 64-bit arithmetic.
    const Eigen::Index plane_size = static_cast<Eigen::Index>(oH) * oW;

    // One work item per output position (n, h, w). Different items write
    // disjoint output elements, so the shards need no synchronization.
    const auto work = [&](Eigen::Index start, Eigen::Index end) -> void {
      for (Eigen::Index id = start; id < end; ++id) {
        const int n = static_cast<int>(id / plane_size);
        const int h = static_cast<int>((id / oW) % oH);
        const int w = static_cast<int>(id % oW);
        // Patch center in input_a for this output position.
        // NOTE(review): pad is subtracted before scaling by stride_1, i.e.
        // (h - pad) * stride_1 instead of h * stride_1 - pad. The two differ
        // when stride_1 > 1 -- confirm this matches the output-shape
        // computation and the GPU kernel.
        const int h1 = (h - pad) * stride_1 + max_displacement + kernel_rad;
        const int w1 = (w - pad) * stride_1 + max_displacement + kernel_rad;
        for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
          for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
            // Output channel holding this displacement.
            const int tc = (tj + displacement_rad) * displacement_size +
                           (ti + displacement_rad);

            // Center of the displaced patch in input_b.
            const int w2 = w1 + ti * stride_2;
            const int h2 = h1 + tj * stride_2;

            // Accumulate in a local instead of re-indexing the output tensor
            // (which the compiler cannot prove is not aliased by the inputs)
            // on every multiply-add. The summation order is unchanged.
            Dtype sum = Dtype(0);
            for (int j = -kernel_rad; j <= kernel_rad; ++j) {
              // Taps outside either input contribute nothing.
              if (!FastBoundsCheck(h1 + j, iH) || !FastBoundsCheck(h2 + j, iH))
                continue;
              for (int i = -kernel_rad; i <= kernel_rad; ++i) {
                if (!FastBoundsCheck(w1 + i, iW) ||
                    !FastBoundsCheck(w2 + i, iW))
                  continue;
                for (int c = 0; c < iC; ++c) {
                  // eq. (1) in FlowNet: Learning Optical Flow with
                  // Convolutional Networks
                  if (is_NCHW) {
                    sum += input_a(n, c, h1 + j, w1 + i) *
                           input_b(n, c, h2 + j, w2 + i);
                  } else {
                    sum += input_a(n, h1 + j, w1 + i, c) *
                           input_b(n, h2 + j, w2 + i, c);
                  }
                }
              }
            }
            output(n, tc, h, w) = sum / K;
          }
        }
      }
    };
    auto thread_pool =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    // Count work items in int64: oN * oH * oW can exceed the int32 range.
    const int64 total_pixels = static_cast<int64>(oN) * plane_size;
    thread_pool->ParallelFor(total_pixels, cost_per_pixel, work);
    return Status::OK();
  }