in include/tvm/topi/transform.h [1096:1248]
inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch_dims, int axis,
                   std::string mode = "clip", std::string name = "T_take",
                   std::string tag = kInjective) {
  // Gathers entries of `a` along `axis` at the positions given by `indices` (a Tensor, or a
  // scalar PrimExpr, in which case batch_dims must be 0). The first `batch_dims` dimensions
  // of `a` and `indices` are shared: each batch slice of `a` is indexed only by the matching
  // batch slice of `indices`.
  //
  // Result shape: a.shape[:axis] + indices.shape[batch_dims:] + a.shape[axis + 1:].
  //
  // A negative `axis` counts from the end of a's rank; a negative `batch_dims` counts from the
  // end of indices' rank. `mode` controls how an index value is mapped onto a.shape[axis]:
  //   "clip" - clamp to [0, a.shape[axis] - 1];
  //   "fast" - use the value unchanged (no bounds handling at all);
  //   "wrap" - reduce modulo a.shape[axis], so negative values wrap to the upper range.
  // Any other `mode` string is treated as "wrap".
  const int ndim = static_cast<int>(a->shape.size());
  if (axis < 0) {
    axis += ndim;
  }
  ICHECK_GE(axis, 0) << "axis out of bounds";
  ICHECK_LT(axis, ndim) << "axis out of bounds";
  const PrimExpr axis_dim = a->shape[axis];

  const Array<PrimExpr> indices_shape = [&]() -> Array<PrimExpr> {
    if (auto tensor = indices.as<TensorNode>()) {
      return tensor->shape;
    }
    return {};  // scalar PrimExpr index: rank 0
  }();
  const int indices_len = static_cast<int>(indices_shape.size());

  int batch_dims_ = batch_dims;
  if (batch_dims_ != 0) {
    ICHECK_GE(batch_dims_, -indices_len) << "batch_dims out of bounds";
    ICHECK_LE(batch_dims_, indices_len) << "batch_dims out of bounds";
    if (batch_dims_ < 0) {
      batch_dims_ = indices_len + batch_dims_;
    }
    ICHECK_LT(batch_dims_, ndim) << "batch_dims out of bounds";
    ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
    for (int i = 0; i < batch_dims_; ++i) {
      // Only statically known extents can be compared here. A blind cast to IntImm would be
      // undefined behavior for symbolic extents, so those are left unchecked.
      const auto* a_extent = a->shape[i].as<IntImmNode>();
      const auto* idx_extent = indices_shape[i].as<IntImmNode>();
      if (a_extent != nullptr && idx_extent != nullptr) {
        ICHECK_EQ(a_extent->value, idx_extent->value)
            << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
      }
    }
  }

  // Number of output dimensions contributed by the non-batch part of `indices`; they occupy
  // out_index[axis, axis + num_index_dims) in place of a's `axis` dimension.
  const int num_index_dims = indices_len - batch_dims_;

  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] + a.shape[axis + 1:].
  // a.shape[:axis] already contains the shared batch dims because batch_dims_ <= axis.
  Array<PrimExpr> out_shape;
  for (int i = 0; i < axis; ++i) {
    out_shape.push_back(a->shape[i]);
  }
  for (int i = batch_dims_; i < indices_len; ++i) {
    out_shape.push_back(indices_shape[i]);
  }
  for (int i = axis + 1; i < ndim; ++i) {
    out_shape.push_back(a->shape[i]);
  }

  // Reads the index value at `indices_position`; a scalar PrimExpr index takes no position.
  auto get_index = [&](const Array<PrimExpr>& indices_position) -> PrimExpr {
    if (auto tensor = indices.as<Tensor>()) {
      return tensor.value()(indices_position);
    } else if (auto prim = indices.as<PrimExpr>()) {
      ICHECK_EQ(indices_position.size(), 0);
      return prim.value();
    } else {
      LOG(FATAL) << "Variant did not contain either allowed type";
    }
  };

  // Applies the out-of-range policy selected by `mode` to a raw index value.
  auto normalize_index = [&](const PrimExpr& idx) -> PrimExpr {
    if (mode == "clip") {
      return tvm::min(tvm::max(0, idx), axis_dim - 1);
    } else if (mode == "fast") {
      return idx;
    } else {  // mode == "wrap" (also the fallback for unrecognized modes)
      return truncmod(truncmod(idx, axis_dim) + axis_dim, axis_dim);
    }
  };

  return compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        // Position inside `indices`: the shared batch coordinates, followed by the output
        // coordinates that stand in for a's `axis` dimension.
        Array<PrimExpr> indices_position;
        for (int j = 0; j < batch_dims_; ++j) {
          indices_position.push_back(out_index[j]);
        }
        for (int j = axis; j < axis + num_index_dims; ++j) {
          indices_position.push_back(out_index[j]);
        }
        // Position inside `a`: leading dims copied through, the gathered index at `axis`,
        // then the trailing dims that follow the index dimensions in the output.
        Array<PrimExpr> real_indices;
        for (int j = 0; j < axis; ++j) {
          real_indices.push_back(out_index[j]);
        }
        real_indices.push_back(normalize_index(get_index(indices_position)));
        for (size_t j = static_cast<size_t>(axis + num_index_dims); j < out_index.size(); ++j) {
          real_indices.push_back(out_index[j]);
        }
        return a(real_indices);
      },
      name, tag);
}