void TopKImpl()

in src/operator/tensor/ordering_op-inl.h [148:275]


/*!
 * \brief Top-k of `src` along the axis selected by `param`, written into `ret`
 *        according to `param.ret_typ`.
 *
 * Strategy: copy the (possibly transposed) input into a flat workspace buffer that is
 * treated as batch_size consecutive runs of element_num entries. Sort the whole buffer
 * once together with its flat indices. Then use stable sorts keyed on the batch id to
 * regroup the entries batch by batch while keeping each batch's internal sorted order.
 * The first k entries of every batch are then the requested top-k.
 *
 * \param ctx      run context; supplies the device stream.
 * \param resource temp-space resource that backs the sort workspace.
 * \param src      input data; every blob in `ret` must have the same type_flag_ as src.
 * \param ret      outputs:
 *                 - kReturnMask: ret[0] is zero-filled, then marked at the flat positions
 *                   of the selected elements.
 *                 - kReturnIndices: ret[0] receives, for each top-k element, its position
 *                   within its batch (stored as real_t).
 *                 - any other ret_typ: ret[0] receives the top-k values and ret[1] their
 *                   within-batch positions (stored as real_t).
 * \param param    operator parameters, decoded by ParseTopKParam.
 */
void TopKImpl(RunContext ctx,
              Resource resource,
              const TBlob& src,
              const std::vector<TBlob>& ret,
              const TopKParam& param) {
  using namespace mshadow;
  using namespace mshadow::expr;
  // Bind by reference: copying each TBlob (and the TShape it carries) is needless work.
  for (const auto& ret_ele : ret) {
    CHECK_EQ(ret_ele.type_flag_, src.type_flag_);
  }
  // 1. Parse and initialize information
  Stream<xpu> *s = ctx.get_stream<xpu>();
  Tensor<xpu, 1, real_t> workspace;
  Tensor<xpu, 1, real_t> sorted_dat, indices, batch_id, sel_indices;
  Tensor<xpu, 2, real_t> mask_val;
  int batch_size = 0;   // number of batches (filled in by ParseTopKParam)
  int element_num = 0;  // size of each batch (filled in by ParseTopKParam)
  int axis = 0;
  bool do_transpose = false;
  bool is_ascend = false;
  int k = 0;
  TShape target_shape;
  ParseTopKParam(src.shape_, param,
                 &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
  Tensor<xpu, 3, real_t> dat = src.FlatTo3D<xpu, real_t>(axis, axis, s);
  // Workspace layout: [sorted_dat | indices | batch_id], each src.Size() long, followed
  // in mask mode only by [sel_indices | mask_val], each batch_size * k long.
  if (param.ret_typ == topk_enum::kReturnMask) {
    workspace =
      resource.get_space_typed<xpu, 1, real_t>(Shape1(src.Size() * 3 + 2 * batch_size * k), s);
  } else {
    workspace = resource.get_space_typed<xpu, 1, real_t>(Shape1(src.Size() * 3), s);
  }
  sorted_dat = Tensor<xpu, 1, real_t>(workspace.dptr_,
                                      Shape1(src.Size()), s);  // values to be sorted
  indices = Tensor<xpu, 1, real_t>(workspace.dptr_ + src.Size(),
                                   Shape1(src.Size()), s);  // flat index of each value
  batch_id = Tensor<xpu, 1, real_t>(workspace.dptr_ + 2 * src.Size(),
                                    Shape1(src.Size()), s);  // batch of each value
  // Flatten the input. With do_transpose, the middle (axis) dimension is moved last first.
  if (do_transpose) {
    sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
  } else {
    sorted_dat = reshape(dat, Shape1(src.Size()));
  }
  // NOTE(review): indices are carried as real_t. If real_t is 32-bit float (assumed -- confirm),
  // integers above 2^24 are not exactly representable. Inputs with more elements than that
  // may then yield wrong indices or mask positions.
  indices = range<real_t>(0, batch_size * element_num);
  CHECK_EQ(sorted_dat.CheckContiguous(), true);
  CHECK_EQ(indices.CheckContiguous(), true);
  if (param.ret_typ == topk_enum::kReturnMask) {
    sel_indices = Tensor<xpu, 1, real_t>(workspace.dptr_ + 3 * src.Size(),
                                         Shape1(batch_size * k), s);
    mask_val = Tensor<xpu, 2, real_t>(workspace.dptr_ + 3 * src.Size() + batch_size * k,
                                      Shape2(batch_size * k, 1), s);
    mask_val = scalar<real_t>(1);  // value written at every selected position
    CHECK_EQ(sel_indices.CheckContiguous(), true);
    CHECK_EQ(mask_val.CheckContiguous(), true);
  }

  // 2. Perform inplace batch sort using the `SortByKey` in MShadow.
  // After sorting, each batch in `sorted_dat` is sorted in the requested order and
  // `indices` holds the flat index each entry came from.
  // First, sort everything globally, carrying the flat indices along.
  mxnet::op::SortByKey(sorted_dat, indices, is_ascend);
  // Batch of each entry, derived from its flat index.
  batch_id = F<mshadow_op::floor>(indices / static_cast<real_t>(element_num));
  // SortByKey is stable, so grouping by batch_id keeps each batch's sorted order.
  // This call regroups sorted_dat, and batch_id itself ends up sorted.
  mxnet::op::SortByKey(batch_id, sorted_dat, true);
  // batch_id was consumed as a key above. Recompute it in the current `indices` order so
  // the same regrouping can be applied to `indices`. Afterwards batch_id[i] is the batch of slot i.
  batch_id = F<mshadow_op::floor>(indices / static_cast<real_t>(element_num));
  mxnet::op::SortByKey(batch_id, indices, true);

  // 3. Assign results to the ret blob
  if (param.ret_typ == topk_enum::kReturnMask) {
    Tensor<xpu, 2, real_t> ret_mask =
      ret[0].get_with_shape<xpu, 2, real_t>(Shape2(ret[0].Size(), 1), s);
    ret_mask = scalar<real_t>(0);
    // The first k flat indices of every batch identify the selected elements.
    sel_indices = reshape(slice<1>(inplace_reshape(indices, Shape2(batch_size, element_num)),
                                   0, k),
                          Shape1(batch_size * k));
    if (do_transpose) {
      // sel_indices address the transposed buffer. Presumably transpose_indices maps them
      // back to flat positions in the untransposed src (it is defined elsewhere).
      TShape src_shape = src.shape_.FlatTo3D(axis);
      sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                      Shape3(0, 2, 1));
    }
    IndexFill(ret_mask, sel_indices, mask_val);
    return;
  }

  // Outside mask mode, report positions relative to the start of each batch rather than
  // flat positions in the workspace.
  indices -= batch_id * static_cast<real_t>(element_num);

  // Copies the first k entries of every batch of `flat` into `out`. When the input was
  // transposed above, the result is transposed back.
  auto copy_topk = [&](Tensor<xpu, 1, real_t> flat, const TBlob& out) {
    if (do_transpose) {
      Tensor<xpu, 3, real_t> out_3d = out.FlatTo3D<xpu, real_t>(axis, axis, s);
      out_3d = transpose(
                 slice<2>(inplace_reshape(flat,
                                          Shape3(out_3d.shape_[0],
                                                 out_3d.shape_[2],
                                                 element_num)),
                          0, k),
                 Shape3(0, 2, 1));
    } else {
      Tensor<xpu, 2, real_t> out_2d =
        out.get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
      out_2d = slice<1>(inplace_reshape(flat, Shape2(batch_size, element_num)), 0, k);
    }
  };

  if (param.ret_typ == topk_enum::kReturnIndices) {
    copy_topk(indices, ret[0]);
  } else {
    copy_topk(sorted_dat, ret[0]);
    copy_topk(indices, ret[1]);
  }
}