NmsResult BuildNms()

in torch_xla/csrc/nms_op.cpp [151:296]


// Builds the XLA computation for greedy non-maximum suppression (NMS).
//
// Boxes are ranked by descending score. A [num_boxes, num_boxes] mask of box
// pairs whose IoU is strictly greater than `iou_threshold` is handed to a
// while loop (WhileCondFn / SuppressBodyFn, defined elsewhere in this file)
// that decides which boxes survive suppression. A surviving box is reported
// only if its score is strictly greater than `score_threshold` and the loop
// actually visited it.
//
// Args:
//   boxes: rank-2 [num_boxes, 4] box coordinates. The code names the columns
//     (y0, x0, y1, x1); the two corners may be given in either order because
//     they are normalized below.
//   scores: rank-1 [num_boxes] scores, one per box.
//   score_threshold: score threshold; it is broadcast to [num_boxes], so it
//     is presumably a scalar.
//   iou_threshold: IoU threshold; presumably a scalar (it is added to a
//     [num_boxes, num_boxes] zero tensor).
//   output_size: maximum number of indices to return. Must satisfy
//     0 <= output_size <= num_boxes.
//
// Returns:
//   {selected_indices, num_valid}. `selected_indices` has shape [output_size]
//   and holds indices into the original `boxes` / `scores`, kept boxes first
//   in descending score order. Only the first `num_valid` entries refer to
//   kept boxes. `num_valid` is an S32 scalar equal to
//   min(number of kept boxes, output_size).
NmsResult BuildNms(xla::XlaOp boxes, xla::XlaOp scores,
                   xla::XlaOp score_threshold, xla::XlaOp iou_threshold,
                   xla::int64_t output_size) {
  const xla::Shape& boxes_shape = XlaHelpers::ShapeOfXlaOp(boxes);
  const xla::Shape& scores_shape = XlaHelpers::ShapeOfXlaOp(scores);
  // Validate the rank before indexing into the dimensions.
  XLA_CHECK_EQ(boxes_shape.rank(), 2);
  xla::int64_t num_boxes = boxes_shape.dimensions(0);
  XLA_CHECK_EQ(boxes_shape.dimensions(1), 4);
  XLA_CHECK_EQ(scores_shape.rank(), 1);
  XLA_CHECK_EQ(scores_shape.dimensions(0), num_boxes);
  XLA_CHECK_LT(num_boxes, std::numeric_limits<xla::int32>::max());
  XLA_CHECK_GE(output_size, 0);
  XLA_CHECK_LT(output_size, std::numeric_limits<xla::int32>::max());
  // TopK below cannot select more elements than there are boxes. Fail early
  // with a clear message instead of an opaque XLA shape error.
  XLA_CHECK_GE(num_boxes, output_size);

  xla::XlaBuilder* builder = boxes.builder();

  // Transpose to [4, num_boxes] so each coordinate is a row, then sort every
  // row by descending score. The scores are broadcast to [4, num_boxes] to
  // serve as the sort key. Both this sort and the index sort below are stable.
  // An unstable sort gives no guarantee that tied scores end up in the same
  // order in both results, which would make box geometry and reported
  // indices disagree.
  xla::XlaOp boxes_transposed = xla::Transpose(boxes, {1, 0});
  xla::XlaOp boxes_sorted = xla::GetTupleElement(
      xla::Sort({xla::Broadcast(scores, {4}), boxes_transposed},
                xla::CreateScalarGtComputation(
                    {scores_shape.element_type(), boxes_shape.element_type()},
                    builder),
                /*dimension=*/1, /*is_stable=*/true),
      1);

  // Track the mapping of indices into the sorted domain: position i of the
  // sorted arrays corresponds to input index indices_sorted[i].
  xla::XlaOp iota_indices =
      xla::Iota(builder, xla::PrimitiveType::S32, num_boxes);
  xla::XlaOp indices_sort = xla::Sort(
      {scores, iota_indices},
      xla::CreateScalarGtComputation(
          {scores_shape.element_type(), xla::PrimitiveType::S32}, builder),
      /*dimension=*/0, /*is_stable=*/true);
  xla::XlaOp scores_sorted = xla::GetTupleElement(indices_sort, 0);
  xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1);

  // Extracts coordinate row `row` of boxes_sorted as a [num_boxes] op.
  // 'c_y0' denotes 'coordinate' y0.
  auto coordinate_row = [&](xla::int64_t row) -> xla::XlaOp {
    return xla::Reshape(xla::SliceInDim(boxes_sorted,
                                        /*start_index=*/row,
                                        /*limit_index=*/row + 1,
                                        /*stride=*/1,
                                        /*dimno=*/0),
                        {num_boxes});
  };
  // Shapes are henceforth [num_boxes].
  xla::XlaOp c_y0 = coordinate_row(0);
  xla::XlaOp c_x0 = coordinate_row(1);
  xla::XlaOp c_y1 = coordinate_row(2);
  xla::XlaOp c_x1 = coordinate_row(3);

  // Normalize the corners so that y1 <= y2 and x1 <= x2 (for non-NaN
  // coordinates), whatever order the input used.
  xla::XlaOp y_in_order = xla::Le(c_y0, c_y1);
  xla::XlaOp x_in_order = xla::Le(c_x0, c_x1);
  xla::XlaOp y1 = xla::Select(y_in_order, c_y0, c_y1);
  xla::XlaOp y2 = xla::Select(y_in_order, c_y1, c_y0);
  xla::XlaOp x1 = xla::Select(x_in_order, c_x0, c_x1);
  xla::XlaOp x2 = xla::Select(x_in_order, c_x1, c_x0);
  xla::XlaOp area = (y2 - y1) * (x2 - x1);

  // Add a leading unit dimension. Shapes are henceforth [1, num_boxes].
  y1 = xla::Broadcast(y1, {1});
  y2 = xla::Broadcast(y2, {1});
  x1 = xla::Broadcast(x1, {1});
  x2 = xla::Broadcast(x2, {1});
  area = xla::Broadcast(area, {1});

  // Pairwise terms via [1, n] (op) [n, 1] degenerate-dimension broadcasting.
  // Shapes are henceforth [num_boxes, num_boxes].
  xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0}));
  xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0}));
  xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0}));
  xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0}));
  auto square_zero = xla::ZerosLike(i_xmin);

  // Negative extents mean the boxes do not overlap; clamp them to zero.
  xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) *
                      xla::Max(i_ymax - i_ymin, square_zero);
  xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area;
  // NOTE: for floating-point inputs, two zero-area boxes give 0/0 = NaN here.
  // Gt(NaN, t) is false, so their mask entry below is false.
  xla::XlaOp iou = i_area / u_area;

  xla::XlaOp iou_threshold_mask = xla::Gt(iou, iou_threshold + square_zero);
  xla::XlaOp included_iou =
      xla::Broadcast(xla::ConstantR0<bool>(builder, true), {num_boxes});
  if (boxes_shape.is_dynamic_dimension(0)) {
    // Update included_iou's size to match boxes actual size.
    included_iou = xla::SetDimensionSize(
        included_iou, XlaHelpers::GetDimensionsSize({boxes}, {0}).size, 0);
  }

  xla::XlaOp zero_s32 = xla::Zero(builder, xla::PrimitiveType::S32);
  xla::XlaOp one_s32 = xla::One(builder, xla::PrimitiveType::S32);
  // Loop state, in the order the loop functions consume it.
  std::vector<xla::XlaOp> init_values = {
      zero_s32,            // col_idx
      zero_s32,            // num_outputs
      iou_threshold_mask,  // pairwise IoU > iou_threshold
      included_iou,        // per-box "still included" flags
  };

  auto suppress_loop_result = ConsumeValue(xla::WhileLoopHelper(
      WhileCondFn(num_boxes, output_size), SuppressBodyFn(num_boxes),
      init_values, "BoxSuppressLoop", builder));

  xla::XlaOp included_score =
      xla::Gt(scores_sorted, xla::Broadcast(score_threshold, {num_boxes}));
  xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]);

  // Only consider boxes over which we have iterated. This allows for accurate
  // counting. DynamicSlice would require knowledge of the size of the output.
  xla::XlaOp valid_elem = xla::Lt(
      iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes}));
  included = xla::And(included, valid_elem);

  // Excluded boxes get the lowest representable score so TopK ranks every
  // included box ahead of them.
  xla::XlaOp neg_inf = xla::Broadcast(
      xla::MinValue(builder, scores_shape.element_type()), {num_boxes});
  xla::XlaOp scores_included = xla::Select(included, scores_sorted, neg_inf);
  xla::XlaOp output_tuple = xla::TopK(scores_included, output_size);
  xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1);

  // Calculate num_valid.
  // Note: num_valid cannot be taken from the loop outputs, because outputs
  // can be suppressed by score threshold.
  xla::XlaOp ones_included =
      xla::Select(included, xla::Broadcast(one_s32, {num_boxes}),
                  xla::Broadcast(zero_s32, {num_boxes}));
  // num_valid is scalar. Value should be bound by output_size.
  xla::XlaOp num_valid_total = xla::Reduce(
      ones_included,
      /*init_value=*/zero_s32,
      /*computation=*/
      xla::CreateScalarAddComputation(xla::PrimitiveType::S32, builder),
      /*dimensions_to_reduce=*/{0});
  // output_size was checked above to fit in int32.
  xla::XlaOp num_valid = xla::Min(
      num_valid_total,
      xla::ConstantR0<xla::int32>(builder,
                                  static_cast<xla::int32>(output_size)));

  // Re-index into the original scores input tensor, using a Gather.
  // Boxes were suppressed in the sorted domain.
  xla::XlaOp selected_indices =
      NmsGather(indices_sorted, scores_shape.dimensions(),
                selected_indices_sorted, {output_size},
                /*axis=*/0);
  return {selected_indices, num_valid};
}