in src/operator/tensor/ordering_op-inl.h [148:275]
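/*!
 * \brief Implementation of the TopK operation.
 *
 * \param ctx the running context
 * \param resource temporary resource handler
 * \param src the source blob
 * \param ret the destination blobs
 * \param param the topk parameters
 * \tparam xpu the device type
 */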
template<typename xpu>
void TopKImpl(RunContext ctx,
              Resource resource,
              const TBlob& src,
              const std::vector<TBlob>& ret,
              const TopKParam& param) {
  using namespace mshadow;
  using namespace mshadow::expr;
  for (auto ret_ele : ret) {
    CHECK_EQ(ret_ele.type_flag_, src.type_flag_);
  }
  // 1. Parse and initialize information
  Stream<xpu> *s = ctx.get_stream<xpu>();
  Tensor<xpu, 1, real_t> workspace;
  Tensor<xpu, 1, real_t> sorted_dat, indices, batch_id, sel_indices;
  Tensor<xpu, 2, real_t> mask_val;
  int batch_size, element_num;  // the number of batches and the size of each batch
  int axis = 0;
  bool do_transpose = false;
  bool is_ascend = false;
  int k = 0;
  TShape target_shape;
  ParseTopKParam(src.shape_, param,
                 &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
  Tensor<xpu, 3, real_t> dat = src.FlatTo3D<xpu, real_t>(axis, axis, s);
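  // Workspace layout: three full-size buffers (sorted data, global indices, batch ids),
  // plus, for the mask case, batch_size * k selected positions and batch_size * k mask values.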
  if (param.ret_typ == topk_enum::kReturnMask) {
    workspace =
      resource.get_space_typed<xpu, 1, real_t>(Shape1(src.Size() * 3 + 2 * batch_size * k), s);
  } else {
    workspace = resource.get_space_typed<xpu, 1, real_t>(Shape1(src.Size() * 3), s);
  }
  sorted_dat = Tensor<xpu, 1, real_t>(workspace.dptr_,
                                      Shape1(src.Size()), s);  // contains the sorted data
  indices = Tensor<xpu, 1, real_t>(workspace.dptr_ + src.Size(),
                                   Shape1(src.Size()), s);  // indices in the original matrix
  batch_id = Tensor<xpu, 1, real_t>(workspace.dptr_ + 2 * src.Size(),
                                    Shape1(src.Size()), s);  // batch id in the original matrix
  if (do_transpose) {
    // Move the sorted axis to the innermost position so that each batch is contiguous.
    sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
  } else {
    sorted_dat = reshape(dat, Shape1(src.Size()));
  }
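  // Initialize `indices` with the flattened global positions 0, 1, ..., batch_size * element_num - 1.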
  indices = range<real_t>(0, batch_size * element_num);
  CHECK_EQ(sorted_dat.CheckContiguous(), true);
  CHECK_EQ(indices.CheckContiguous(), true);
  if (param.ret_typ == topk_enum::kReturnMask) {
    sel_indices = Tensor<xpu, 1, real_t>(workspace.dptr_ + 3 * src.Size(),
                                         Shape1(batch_size * k), s);
    mask_val = Tensor<xpu, 2, real_t>(workspace.dptr_ + 3 * src.Size() + batch_size * k,
                                      Shape2(batch_size * k, 1), s);
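    // The mask values are all ones; IndexFill will scatter them into the output later.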
    mask_val = scalar<real_t>(1);
    CHECK_EQ(sel_indices.CheckContiguous(), true);
    CHECK_EQ(mask_val.CheckContiguous(), true);
  }
  // 2. Perform an in-place batch sort using `SortByKey`.
  // After sorting, each batch in `sorted_dat` will be sorted in the requested order
  // and `indices` will contain the corresponding global indices in the original matrix.
  // First, sort the data globally by value and keep a record of the global indices.
  mxnet::op::SortByKey(sorted_dat, indices, is_ascend);
  // Calculate the batch index of every element.
  batch_id = F<mshadow_op::floor>(indices / static_cast<real_t>(element_num));
  // Since SortByKey performs a stable sort, sorting by `batch_id` regroups `sorted_dat`
  // by batch while preserving the value order within each batch.
  mxnet::op::SortByKey(batch_id, sorted_dat, true);
  // Recompute the batch ids (the previous sort permuted the `batch_id` keys) and apply
  // the same stable reordering to `indices`.
  batch_id = F<mshadow_op::floor>(indices / static_cast<real_t>(element_num));
  mxnet::op::SortByKey(batch_id, indices, true);
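  // At this point each batch occupies a contiguous run of `element_num` entries in
  // `sorted_dat`, sorted by value, and `indices` holds the matching global positions.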
  // 3. Assign results to the ret blob
  if (param.ret_typ == topk_enum::kReturnMask) {
    Tensor<xpu, 2, real_t> ret_mask =
      ret[0].get_with_shape<xpu, 2, real_t>(Shape2(ret[0].Size(), 1), s);
    ret_mask = scalar<real_t>(0);
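    // Take the first k global indices of every batch; these are the positions of the
    // top-k elements in the flattened input.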
    sel_indices = reshape(slice<1>(
                              inplace_reshape(indices,
                                              Shape2(batch_size,
                                                     element_num)), 0, k),
                          Shape1(batch_size * k));
    if (do_transpose) {
      TShape src_shape = src.shape_.FlatTo3D(axis);
      CHECK_EQ(sel_indices.CheckContiguous(), true);
      // Map the indices from the transposed layout back to the original layout.
      sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                      Shape3(0, 2, 1));
    }
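    // Scatter the ones in `mask_val` into `ret_mask` at the selected positions.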
    IndexFill(ret_mask, sel_indices, mask_val);
  } else if (param.ret_typ == topk_enum::kReturnIndices) {
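    // Convert the global indices into indices within each batch.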
    indices -= batch_id * static_cast<real_t>(element_num);
    if (do_transpose) {
      Tensor<xpu, 3, real_t> ret_indices = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
      ret_indices = transpose(
          slice<2>(inplace_reshape(indices,
                                   Shape3(ret_indices.shape_[0],
                                          ret_indices.shape_[2],
                                          element_num)),
                   0, k),
          Shape3(0, 2, 1));
    } else {
      Tensor<xpu, 2, real_t> ret_indices =
        ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
      ret_indices = slice<1>(inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k);
    }
  } else {
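    // The remaining return types write both the top-k values (ret[0]) and their
    // indices (ret[1]). Convert the global indices into indices within each batch.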
    indices -= batch_id * static_cast<real_t>(element_num);
    if (do_transpose) {
      Tensor<xpu, 3, real_t> ret_value = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
      Tensor<xpu, 3, real_t> ret_indices = ret[1].FlatTo3D<xpu, real_t>(axis, axis, s);
      ret_value = transpose(
          slice<2>(inplace_reshape(sorted_dat,
                                   Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)),
                   0, k),
          Shape3(0, 2, 1));
      ret_indices = transpose(
          slice<2>(inplace_reshape(indices,
                                   Shape3(ret_indices.shape_[0],
                                          ret_indices.shape_[2],
                                          element_num)),
                   0, k),
          Shape3(0, 2, 1));
    } else {
      Tensor<xpu, 2, real_t> ret_value =
        ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
      Tensor<xpu, 2, real_t> ret_indices =
        ret[1].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
      ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k);
      ret_indices = slice<1>(inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k);
    }
  }
}
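
// For illustration only: a minimal sketch of how an operator's Forward function might
// dispatch to TopKImpl. The name `TopK` and the way the temporary-space resource is
// fetched follow the usual MXNet operator conventions and are assumptions here, not
// necessarily this file's exact wiring.
template<typename xpu>
void TopK(const nnvm::NodeAttrs& attrs,
          const OpContext& ctx,
          const std::vector<TBlob>& inputs,
          const std::vector<OpReqType>& req,
          const std::vector<TBlob>& outputs) {
  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
  // ctx.requested[0] is assumed to be the temporary-space resource declared by the op.
  TopKImpl<xpu>(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, param);
}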