in torch_xla/csrc/nms_op.cpp [151:296]
// Builds the XLA graph for non-maximum suppression.
//
// Parameters:
//   boxes:           rank-2 operand of shape [num_boxes, 4]. The four columns
//                    are read below as (y0, x0, y1, x1); either corner order
//                    is accepted because the coordinates are min/max
//                    normalized before use.
//   scores:          rank-1 operand of shape [num_boxes].
//   score_threshold: scalar; only boxes with score strictly greater than it
//                    are kept.
//   iou_threshold:   scalar; a pair of boxes whose IOU is strictly greater
//                    than it is marked in the suppression mask.
//   output_size:     static number of indices to emit; must be >= 0 and fit
//                    in a signed 32-bit integer.
//
// Returns an NmsResult built from {selected_indices, num_valid}, in that
// order: selected_indices holds output_size indices into the original
// (unsorted) inputs, and num_valid is an S32 scalar capped at output_size.
//
// Shape violations are reported through XLA_CHECK_*.
NmsResult BuildNms(xla::XlaOp boxes, xla::XlaOp scores,
                   xla::XlaOp score_threshold, xla::XlaOp iou_threshold,
                   xla::int64_t output_size) {
  const xla::Shape& boxes_shape = XlaHelpers::ShapeOfXlaOp(boxes);
  const xla::Shape& scores_shape = XlaHelpers::ShapeOfXlaOp(scores);
  // Validate the rank before indexing into the dimensions, so that a
  // malformed input trips the check instead of an out-of-range access.
  XLA_CHECK_EQ(boxes_shape.rank(), 2);
  XLA_CHECK_EQ(boxes_shape.dimensions(1), 4);
  xla::int64_t num_boxes = boxes_shape.dimensions(0);
  XLA_CHECK_EQ(scores_shape.rank(), 1);
  XLA_CHECK_EQ(scores_shape.dimensions(0), num_boxes);
  // Indices and counters below are S32, so both sizes must fit.
  XLA_CHECK_LT(num_boxes, std::numeric_limits<xla::int32>::max());
  XLA_CHECK_GE(output_size, 0);
  XLA_CHECK_LT(output_size, std::numeric_limits<xla::int32>::max());
  xla::XlaBuilder* builder = boxes.builder();
  // Choose a more convenient layout: [4, num_boxes], one row per coordinate.
  xla::XlaOp boxes_transposed = xla::Transpose(boxes, {1, 0});
  // Sort every coordinate row by the (broadcast) scores using a greater-than
  // comparator. The sort is requested as stable so that boxes with equal
  // scores keep their input order in every row and in the index sort below;
  // otherwise tied boxes could end up with mismatched coordinates/indices.
  xla::XlaOp boxes_sorted = xla::GetTupleElement(
      xla::Sort({xla::Broadcast(scores, {4}), boxes_transposed},
                xla::CreateScalarGtComputation(
                    {scores_shape.element_type(), boxes_shape.element_type()},
                    builder),
                /*dimension=*/1, /*is_stable=*/true),
      1);
  // Track the mapping of indices into sorted domain.
  xla::XlaOp iota_indices =
      xla::Iota(builder, xla::PrimitiveType::S32, num_boxes);
  xla::XlaOp indices_sort = xla::Sort(
      {scores, iota_indices},
      xla::CreateScalarGtComputation(
          {scores_shape.element_type(), xla::PrimitiveType::S32}, builder),
      /*dimension=*/-1, /*is_stable=*/true);
  xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1);
  xla::XlaOp scores_sorted = xla::GetTupleElement(indices_sort, 0);
  // Shapes are henceforth [num_boxes]. 'c_y0' denotes 'coordinate' y0.
  xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                 /*start_index=*/0,
                                                 /*limit_index=*/1,
                                                 /*stride=*/1,
                                                 /*dimno=*/0),
                                 {num_boxes});
  xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                 /*start_index=*/1,
                                                 /*limit_index=*/2,
                                                 /*stride=*/1,
                                                 /*dimno=*/0),
                                 {num_boxes});
  xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                 /*start_index=*/2,
                                                 /*limit_index=*/3,
                                                 /*stride=*/1,
                                                 /*dimno=*/0),
                                 {num_boxes});
  xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                 /*start_index=*/3,
                                                 /*limit_index=*/4,
                                                 /*stride=*/1,
                                                 /*dimno=*/0),
                                 {num_boxes});
  // Normalize the corners so that (y1, x1) is the minimum corner and
  // (y2, x2) the maximum one, whatever order the caller supplied.
  xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1);
  xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0);
  xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1);
  xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0);
  xla::XlaOp area = (y2 - y1) * (x2 - x1);
  // Shapes are henceforth [1, num_boxes].
  y1 = xla::Broadcast(y1, {1});
  y2 = xla::Broadcast(y2, {1});
  x1 = xla::Broadcast(x1, {1});
  x2 = xla::Broadcast(x2, {1});
  area = xla::Broadcast(area, {1});
  // Shapes are henceforth [num_boxes, num_boxes]: each row operand is
  // combined with its own transpose to form all pairwise intersections.
  xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0}));
  xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0}));
  xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0}));
  xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0}));
  auto square_zero = xla::ZerosLike(i_xmin);
  // Clamp at zero so that disjoint boxes contribute no intersection area.
  xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) *
                      xla::Max(i_ymax - i_ymin, square_zero);
  xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area;
  xla::XlaOp iou = i_area / u_area;
  // Adding square_zero expands the scalar threshold to the pairwise shape.
  xla::XlaOp iou_threshold_mask = xla::Gt(iou, iou_threshold + square_zero);
  // Per-box inclusion flags, all true before the suppression loop runs.
  xla::XlaOp included_iou =
      xla::Broadcast(xla::ConstantR0<bool>(builder, true), {num_boxes});
  if (boxes_shape.is_dynamic_dimension(0)) {
    // Update included_iou's size to match boxes actual size.
    included_iou = xla::SetDimensionSize(
        included_iou, XlaHelpers::GetDimensionsSize({boxes}, {0}).size, 0);
  }
  xla::XlaOp zero_s32 = xla::Zero(builder, xla::PrimitiveType::S32);
  xla::XlaOp one_s32 = xla::One(builder, xla::PrimitiveType::S32);
  // Loop-carried state; the positions are relied upon when reading the
  // result below (element 0 and element 3).
  std::vector<xla::XlaOp> init_values;
  init_values.reserve(4);
  init_values.push_back(zero_s32);  // col_idx
  init_values.push_back(zero_s32);  // num_outputs
  init_values.push_back(iou_threshold_mask);
  init_values.push_back(included_iou);
  auto suppress_loop_result = ConsumeValue(xla::WhileLoopHelper(
      WhileCondFn(num_boxes, output_size), SuppressBodyFn(num_boxes),
      init_values, "BoxSuppressLoop", builder));
  xla::XlaOp included_score =
      xla::Gt(scores_sorted, xla::Broadcast(score_threshold, {num_boxes}));
  xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]);
  // Only consider boxes over which we have iterated. This allows for accurate
  // counting. DynamicSlice would require knowledge of the size of the output.
  xla::XlaOp valid_elem = xla::Lt(
      iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes}));
  included = xla::And(included, valid_elem);
  // Excluded boxes get the minimum representable score so that they rank
  // behind every included box in the TopK below.
  xla::XlaOp neg_inf = xla::Broadcast(
      xla::MinValue(builder, scores_shape.element_type()), {num_boxes});
  xla::XlaOp scores_included = xla::Select(included, scores_sorted, neg_inf);
  xla::XlaOp output_tuple = xla::TopK(scores_included, output_size);
  xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1);
  // Calculate num_valid.
  // Note: num_valid cannot be taken from the loop outputs, because outputs
  // can be suppressed by score threshold.
  xla::XlaOp ones_included =
      xla::Select(included, xla::Broadcast(one_s32, {num_boxes}),
                  xla::Broadcast(zero_s32, {num_boxes}));
  // num_valid is scalar. Value should be bound by output_size.
  xla::XlaOp num_valid_total = xla::Reduce(
      ones_included,
      /*init_value=*/zero_s32,
      /*computation=*/
      xla::CreateScalarAddComputation(xla::PrimitiveType::S32, builder),
      /*dimensions_to_reduce=*/{0});
  // The narrowing cast is safe: output_size was checked against the int32
  // maximum above.
  xla::XlaOp num_valid = xla::Min(
      num_valid_total,
      xla::ConstantR0<xla::int32>(builder,
                                  static_cast<xla::int32>(output_size)));
  // Re-index into the original scores input tensor, using a Gather.
  // Boxes were suppressed in the sorted domain.
  xla::XlaOp selected_indices =
      NmsGather(indices_sorted, scores_shape.dimensions(),
                selected_indices_sorted, {output_size},
                /*axis=*/0);
  return {selected_indices, num_valid};
}