include/roi_sampling.h (180 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <type_traits>
#include <tuple>
#include <ATen/ATen.h>
#include "utils/common.h"
// ENUMS
// How reads/writes that fall outside the input plane are handled by the Indexer specializations.
enum class PaddingMode {
  Zero,   // out-of-bounds reads yield 0, out-of-bounds writes are dropped
  Border  // out-of-bounds indices are clamped to the nearest valid row / column
};

// How the four neighbouring values are combined by the Interpolator specializations.
enum class Interpolation {
  Bilinear,  // weighted average of the four neighbours
  Nearest    // value of the single closest neighbour
};
// PROTOTYPES
// Entry points of the ROI sampling operator. Only the declarations live in this header; the
// definitions are expected in the corresponding CPU / CUDA translation units.
//
// NOTE(review): the parameter meanings below are inferred from the names only -- confirm against
// the implementations:
//   x / dy        presumably the input features / the gradient w.r.t. the sampled output
//   bbx           presumably the bounding boxes (regions of interest)
//   idx           presumably the batch index associated with each bounding box
//   out_size      pair of ints, presumably (height, width) of each sampled region
//   in_size       triple of ints describing the input the gradient is computed for -- TODO confirm order
//   interpolation how neighbouring values are combined (see Interpolation)
//   padding       how out-of-bounds accesses are handled (see PaddingMode)
//   valid_mask    flag presumably requesting a validity mask as second output -- TODO confirm

// Forward pass, CPU variant. Returns a pair of tensors.
std::tuple<at::Tensor, at::Tensor> roi_sampling_forward_cpu(
const at::Tensor& x, const at::Tensor& bbx, const at::Tensor& idx, std::tuple<int, int> out_size,
Interpolation interpolation, PaddingMode padding, bool valid_mask);
// Forward pass, CUDA variant. Same signature as the CPU variant.
std::tuple<at::Tensor, at::Tensor> roi_sampling_forward_cuda(
const at::Tensor& x, const at::Tensor& bbx, const at::Tensor& idx, std::tuple<int, int> out_size,
Interpolation interpolation, PaddingMode padding, bool valid_mask);
// Backward pass, CPU variant. Returns a single tensor.
at::Tensor roi_sampling_backward_cpu(
const at::Tensor& dy, const at::Tensor& bbx, const at::Tensor& idx, std::tuple<int, int, int> in_size,
Interpolation interpolation, PaddingMode padding);
// Backward pass, CUDA variant. Same signature as the CPU variant.
at::Tensor roi_sampling_backward_cuda(
const at::Tensor& dy, const at::Tensor& bbx, const at::Tensor& idx, std::tuple<int, int, int> in_size,
Interpolation interpolation, PaddingMode padding);
/* CONVENTIONS
*
* Integer indexes are i (vertical), j (horizontal) and k (generic)
* Continuous coordinates are y (vertical), x (horizontal) and s (generic)
*
* The relation between the two is: y = i + 0.5, x = j + 0.5
*/
// SAMPLER
// Samples a 2D accessor at continuous coordinates (forward) and scatters a gradient back to the
// same four neighbouring locations (backward).
//
// The policy objects decide the details:
//   Indexer       provides get(accessor, i, j) and set(accessor, i, j, value)
//   Interpolator  provides get(dy, dx, four values) and set(dy, dx, grad, four outputs)
template<typename scalar_t, typename coord_t, typename index_t, typename Indexer, typename Interpolator>
struct Sampler {
  Sampler(Indexer indexer, Interpolator interpolator)
      : _indexer(indexer), _interpolator(interpolator) {}

  // Returns the value interpolated at continuous position (y, x) of `accessor`.
  template<typename Accessor>
  HOST_DEVICE scalar_t forward(coord_t y, coord_t x, Accessor accessor) const {
    // Locate the low / high rows and columns surrounding the sample and the fractional offsets
    index_t row_lo, row_hi, col_lo, col_hi;
    coord_t frac_y, frac_x;
    _neighbors(y, row_lo, row_hi, frac_y);
    _neighbors(x, col_lo, col_hi, frac_x);

    // Fetch the four neighbours through the indexer (which applies the padding policy)
    const scalar_t v_ll = _indexer.get(accessor, row_lo, col_lo);
    const scalar_t v_lh = _indexer.get(accessor, row_lo, col_hi);
    const scalar_t v_hl = _indexer.get(accessor, row_hi, col_lo);
    const scalar_t v_hh = _indexer.get(accessor, row_hi, col_hi);

    // Combine them according to the interpolation policy
    return _interpolator.get(frac_y, frac_x, v_ll, v_lh, v_hl, v_hh);
  }

  // Distributes `grad`, the gradient of a sample taken at (y, x), over the four neighbours of
  // that position and accumulates the contributions into `accessor`.
  template<typename Accessor>
  HOST_DEVICE void backward(coord_t y, coord_t x, scalar_t grad, Accessor accessor) const {
    // Same neighbourhood as in forward()
    index_t row_lo, row_hi, col_lo, col_hi;
    coord_t frac_y, frac_x;
    _neighbors(y, row_lo, row_hi, frac_y);
    _neighbors(x, col_lo, col_hi, frac_x);

    // Split the gradient into the four per-neighbour contributions
    scalar_t g_ll, g_lh, g_hl, g_hh;
    _interpolator.set(frac_y, frac_x, grad, g_ll, g_lh, g_hl, g_hh);

    // Accumulate through the indexer, in the order ll, lh, hl, hh
    _indexer.set(accessor, row_lo, col_lo, g_ll);
    _indexer.set(accessor, row_lo, col_hi, g_lh);
    _indexer.set(accessor, row_hi, col_lo, g_hl);
    _indexer.set(accessor, row_hi, col_hi, g_hh);
  }

private:
  // For a continuous coordinate s, computes the two surrounding integer indices and the offset
  // of s from the centre of the lower one (centres sit at k + 0.5, see CONVENTIONS).
  // FLOOR comes from utils/common.h -- presumably a floor function usable on host and device.
  INLINE_HOST_DEVICE void _neighbors(coord_t s, index_t &k_l, index_t &k_h, coord_t &delta) const {
    const index_t lower = static_cast<index_t>(FLOOR(s - 0.5));
    k_l = lower;
    k_h = lower + 1;
    delta = s - (static_cast<coord_t>(lower) + 0.5);
  }

  Indexer _indexer;
  Interpolator _interpolator;
};
// INDEXER
// Common state of all Indexer specializations: the extent of the 2D plane being indexed.
template<typename index_t>
struct IndexerBase {
  // Stores the number of rows (height) and columns (width) of the plane.
  IndexerBase(index_t height, index_t width)
      : _height{height}, _width{width} {}

  index_t _height;  // number of valid row indices
  index_t _width;   // number of valid column indices
};
// Primary template of the indexer policy, declared but never defined: only the explicit
// per-PaddingMode specializations that follow can be instantiated.
template<typename scalar_t, typename index_t, PaddingMode padding>
struct Indexer;
// Zero padding: positions outside the [0, height) x [0, width) plane read as 0 and writes to
// them are discarded.
template<typename scalar_t, typename index_t>
struct Indexer<scalar_t, index_t, PaddingMode::Zero> : IndexerBase<index_t> {
  using IndexerBase<index_t>::IndexerBase;

  // Returns accessor[i][j], or 0 when (i, j) lies outside the plane.
  template<typename Accessor>
  INLINE_HOST_DEVICE scalar_t get(Accessor accessor, index_t i, index_t j) const {
    if (_outside(i, j)) {
      return 0;
    }
    return accessor[i][j];
  }

  // Accumulates `value` into accessor[i][j] via ACCUM_BLOCK; does nothing when (i, j) lies
  // outside the plane.
  template<typename Accessor>
  INLINE_HOST_DEVICE void set(Accessor accessor, index_t i, index_t j, scalar_t value) const {
    if (_outside(i, j)) {
      return;
    }
    ACCUM_BLOCK(accessor[i][j], value);
  }

private:
  // True when the row or the column index falls outside the stored extents.
  INLINE_HOST_DEVICE bool _outside(index_t i, index_t j) const {
    return i < 0 || i >= this->_height || j < 0 || j >= this->_width;
  }
};
// Border padding: out-of-range indices are moved to the closest valid row / column before the
// access, so every read and write lands inside the plane.
template<typename scalar_t, typename index_t>
struct Indexer<scalar_t, index_t, PaddingMode::Border> : IndexerBase<index_t> {
  using IndexerBase<index_t>::IndexerBase;

  // Returns the element at (i, j) after clamping both indices to the plane.
  template<typename Accessor>
  INLINE_HOST_DEVICE scalar_t get(Accessor accessor, index_t i, index_t j) const {
    const index_t row = _clamp_to(i, this->_height);
    const index_t col = _clamp_to(j, this->_width);
    return accessor[row][col];
  }

  // Accumulates `value` (via ACCUM_BLOCK) into the element at (i, j) after clamping both
  // indices to the plane.
  template<typename Accessor>
  INLINE_HOST_DEVICE void set(Accessor accessor, index_t i, index_t j, scalar_t value) const {
    const index_t row = _clamp_to(i, this->_height);
    const index_t col = _clamp_to(j, this->_width);
    ACCUM_BLOCK(accessor[row][col], value);
  }

private:
  // Clamps k from below at 0 first, then from above at size - 1.
  INLINE_HOST_DEVICE index_t _clamp_to(index_t k, index_t size) const {
    const index_t floored = k < 0 ? 0 : k;
    return floored < size ? floored : size - 1;
  }
};
// INTERPOLATORS
// Primary template of the interpolation policy, declared but never defined: only the explicit
// per-Interpolation specializations that follow can be instantiated.
template<typename scalar_t, typename coord_t, Interpolation interpolation>
struct Interpolator;
// Bilinear interpolation between four neighbours. delta_y / delta_x are the offsets of the
// sample from the low row / low column; the suffixes l / h denote the low / high neighbour,
// first along y and then along x (p_lh = low row, high column).
template<typename scalar_t, typename coord_t>
struct Interpolator<scalar_t, coord_t, Interpolation::Bilinear> {
  // Interpolates horizontally on the low and on the high row, then vertically between the two.
  INLINE_HOST_DEVICE scalar_t get(
      coord_t delta_y, coord_t delta_x, scalar_t p_ll, scalar_t p_lh, scalar_t p_hl, scalar_t p_hh) const {
    const coord_t w_x_lo = 1 - delta_x;
    const coord_t w_y_lo = 1 - delta_y;

    const scalar_t row_lo = w_x_lo * p_ll + delta_x * p_lh;
    const scalar_t row_hi = w_x_lo * p_hl + delta_x * p_hh;
    return w_y_lo * row_lo + delta_y * row_hi;
  }

  // Splits `value` into the four per-neighbour shares, each weighted by the product of the
  // horizontal and vertical interpolation weights of that neighbour.
  INLINE_HOST_DEVICE void set(
      coord_t delta_y, coord_t delta_x, scalar_t value,
      scalar_t &p_ll, scalar_t &p_lh, scalar_t &p_hl, scalar_t &p_hh) const {
    const coord_t w_x_lo = 1 - delta_x;
    const coord_t w_y_lo = 1 - delta_y;

    p_ll = w_x_lo * w_y_lo * value;
    p_lh = delta_x * w_y_lo * value;
    p_hl = w_x_lo * delta_y * value;
    p_hh = delta_x * delta_y * value;
  }
};
// Nearest-neighbour interpolation between four neighbours. delta_y / delta_x are the offsets of
// the sample from the low row / low column: an offset below 0.5 selects the low neighbour along
// that axis, an offset of 0.5 or more selects the high one.
//
// The neighbour is picked with a direct selection rather than by multiplying every candidate by
// a 0/1 mask: with masks, a NaN or Inf stored in a neighbour that is NOT selected still poisons
// the result (0 * Inf == NaN), and a non-finite gradient is written to all four locations.
template<typename scalar_t, typename coord_t>
struct Interpolator<scalar_t, coord_t, Interpolation::Nearest> {
  // Returns the value of the neighbour closest to the sample. If no comparison holds (offsets
  // are NaN) the result is 0, as with the mask formulation.
  INLINE_HOST_DEVICE scalar_t get(
      coord_t delta_y, coord_t delta_x, scalar_t p_ll, scalar_t p_lh, scalar_t p_hl, scalar_t p_hh) const {
    const bool y_lo = delta_y < 0.5, y_hi = delta_y >= 0.5;
    const bool x_lo = delta_x < 0.5, x_hi = delta_x >= 0.5;

    return (y_lo && x_lo) ? p_ll
         : (y_lo && x_hi) ? p_lh
         : (y_hi && x_lo) ? p_hl
         : (y_hi && x_hi) ? p_hh
         : static_cast<scalar_t>(0);
  }

  // Routes `value` entirely to the neighbour closest to the sample; the other three outputs are
  // set to exactly 0 regardless of `value`.
  INLINE_HOST_DEVICE void set(
      coord_t delta_y, coord_t delta_x, scalar_t value,
      scalar_t &p_ll, scalar_t &p_lh, scalar_t &p_hl, scalar_t &p_hh) const {
    const bool y_lo = delta_y < 0.5, y_hi = delta_y >= 0.5;
    const bool x_lo = delta_x < 0.5, x_hi = delta_x >= 0.5;
    const scalar_t zero = static_cast<scalar_t>(0);

    p_ll = (y_lo && x_lo) ? value : zero;
    p_lh = (y_lo && x_hi) ? value : zero;
    p_hl = (y_hi && x_lo) ? value : zero;
    p_hh = (y_hi && x_hi) ? value : zero;
  }
};
// UTILITY FUNCTIONS AND MACROS
// Maps a coordinate expressed in ROI space (0 .. roi_size) onto the image-space segment
// [s0_img, s1_img]: the relative position inside the ROI is scaled by the segment length and
// shifted by the segment start.
template<typename coord_t>
INLINE_HOST_DEVICE coord_t roi_to_img(coord_t s_roi, coord_t s0_img, coord_t s1_img, coord_t roi_size) {
  const coord_t relative = s_roi / roi_size;
  const coord_t extent = s1_img - s0_img;
  return relative * extent + s0_img;
}
// Rescales a coordinate from an axis of length size_in to an axis of length size_out.
template<typename coord_t>
INLINE_HOST_DEVICE coord_t img_to_img(coord_t s, coord_t size_in, coord_t size_out) {
  const coord_t relative = s / size_in;
  return relative * size_out;
}
// Declares, in the scope where it is expanded, the three type aliases that select the sampling
// policies for one (INTERPOLATION, PADDING) combination:
//   indexer_t       Indexer specialization for PADDING
//   interpolator_t  Interpolator specialization for INTERPOLATION
//   sampler_t       Sampler built from the two policies above
// The expansion refers to scalar_t, coord_t and index_t, which must already be visible at the
// point of use.
#define INTERPOLATION_PADDING_DEFINES(INTERPOLATION, PADDING) \
using indexer_t = Indexer<scalar_t, index_t, PADDING>; \
using interpolator_t = Interpolator<scalar_t, coord_t, INTERPOLATION>; \
using sampler_t = Sampler<scalar_t, coord_t, index_t, indexer_t, interpolator_t>;
// Turns the run-time values INTERPOLATION and PADDING into compile-time policies: for the
// matching combination, the aliases indexer_t / interpolator_t / sampler_t are defined (see
// INTERPOLATION_PADDING_DEFINES) and the callable passed as the variadic argument is invoked
// and its result returned. scalar_t, coord_t and index_t must be visible where the macro is
// expanded.
//
// Bilinear interpolation is rejected for integral scalar_t.
//
// Each inner switch is followed by an explicit break, so an unexpected padding value can no
// longer fall through from the Bilinear case into the Nearest one, and any combination that is
// not handled ends in an error instead of running off the end of the lambda.
#define DISPATCH_INTERPOLATION_PADDING_MODES(INTERPOLATION, PADDING, ...)                      \
  [&] {                                                                                        \
    switch (INTERPOLATION) {                                                                   \
      case Interpolation::Bilinear:                                                            \
        AT_CHECK(!std::is_integral<scalar_t>::value,                                           \
                 "Bilinear interpolation is not available for integral types");                \
        switch (PADDING) {                                                                     \
          case PaddingMode::Zero: {                                                            \
            INTERPOLATION_PADDING_DEFINES(Interpolation::Bilinear, PaddingMode::Zero)          \
            return __VA_ARGS__();                                                              \
          }                                                                                    \
          case PaddingMode::Border: {                                                          \
            INTERPOLATION_PADDING_DEFINES(Interpolation::Bilinear, PaddingMode::Border)        \
            return __VA_ARGS__();                                                              \
          }                                                                                    \
          default:                                                                             \
            break;                                                                             \
        }                                                                                      \
        break;                                                                                 \
      case Interpolation::Nearest:                                                             \
        switch (PADDING) {                                                                     \
          case PaddingMode::Zero: {                                                            \
            INTERPOLATION_PADDING_DEFINES(Interpolation::Nearest, PaddingMode::Zero)           \
            return __VA_ARGS__();                                                              \
          }                                                                                    \
          case PaddingMode::Border: {                                                          \
            INTERPOLATION_PADDING_DEFINES(Interpolation::Nearest, PaddingMode::Border)         \
            return __VA_ARGS__();                                                              \
          }                                                                                    \
          default:                                                                             \
            break;                                                                             \
        }                                                                                      \
        break;                                                                                 \
      default:                                                                                 \
        break;                                                                                 \
    }                                                                                          \
    AT_CHECK(false, "Unsupported interpolation / padding mode combination");                   \
  }()