std::tuple<at::Tensor, at::Tensor> roi_sampling_forward_cpu()

in src/roi_sampling/roi_sampling_cpu.cpp [47:77]


/// CPU forward pass of ROI sampling.
///
/// Allocates the output tensors, then dispatches on the dtype of `x` and on the
/// requested interpolation/padding combination before handing everything to
/// roi_sampling_forward_impl.
///
/// @param x             Input tensor; read through a 4-D accessor of x's own dtype.
/// @param bbx           Box tensor; read through a 2-D `float` accessor.
/// @param idx           Index tensor; read through a 1-D `int64_t` accessor. Its length
///                      sets the leading dimension of both outputs.
/// @param out_size      Output spatial size: element 0 sizes dim 2 of `y`, element 1 sizes
///                      dim 3 (presumably height, width).
/// @param interpolation Interpolation mode, forwarded to DISPATCH_INTERPOLATION_PADDING_MODES.
/// @param padding       Padding mode, forwarded to DISPATCH_INTERPOLATION_PADDING_MODES.
/// @param valid_mask    When true, a zero-filled uint8 mask of shape
///                      (idx.size(0), out_size[0], out_size[1]) is allocated and passed to the
///                      kernel; when false, a 1x1x1 zero uint8 placeholder is used instead.
/// @return (y, mask). `y` has shape (idx.size(0), x.size(1), out_size[0], out_size[1]) and
///         x's options (dtype/device); `mask` is the uint8 tensor described above.
std::tuple<at::Tensor, at::Tensor> roi_sampling_forward_cpu(
    const at::Tensor& x, const at::Tensor& bbx, const at::Tensor& idx, std::tuple<int, int> out_size,
    Interpolation interpolation, PaddingMode padding, bool valid_mask) {
  const int64_t num_rois = idx.size(0);
  const int64_t channels = x.size(1);
  const int64_t out_h = std::get<0>(out_size);
  const int64_t out_w = std::get<1>(out_size);
  const auto byte_options = x.options().dtype(at::kByte);

  // Output values. Allocated uninitialized.
  // NOTE(review): relies on roi_sampling_forward_impl writing every element — confirm.
  auto y = at::empty({num_rois, channels, out_h, out_w}, x.options());

  // The mask is always a 3-D uint8 tensor, so the accessor below can always be built.
  // When no mask is requested, a single-element placeholder stands in for it.
  at::Tensor mask;
  if (valid_mask) {
    mask = at::zeros({num_rois, out_h, out_w}, byte_options);
  } else {
    mask = at::zeros({1, 1, 1}, byte_options);
  }

  AT_DISPATCH_ALL_TYPES(x.scalar_type(), "roi_sampling_forward_cpu", ([&] {
    // NOTE(review): keep the alias names coord_t / index_t; the
    // DISPATCH_INTERPOLATION_PADDING_MODES expansion may refer to them — check before renaming.
    using coord_t = float;
    using index_t = int64_t;

    // Accessors check rank and dtype, so mismatched inputs fail here.
    auto x_acc = x.accessor<scalar_t, 4>();
    auto bbx_acc = bbx.accessor<coord_t, 2>();
    auto idx_acc = idx.accessor<index_t, 1>();
    auto y_acc = y.accessor<scalar_t, 4>();
    auto mask_acc = mask.accessor<uint8_t, 3>();

    // The macro supplies indexer_t / interpolator_t / sampler_t for the selected modes.
    DISPATCH_INTERPOLATION_PADDING_MODES(interpolation, padding, ([&] {
      // The indexer is sized by dims 2 and 3 of the input.
      indexer_t input_indexer(x.size(2), x.size(3));
      interpolator_t interp;
      sampler_t roi_sampler(input_indexer, interp);

      roi_sampling_forward_impl<scalar_t, coord_t, sampler_t>(
          x_acc, bbx_acc, idx_acc, y_acc, mask_acc, valid_mask, roi_sampler);
    }));
  }));

  return std::make_tuple(y, mask);
}