inline Tensor take()

in include/tvm/topi/transform.h [1096:1248]


/*!
 * \brief Take elements from `a` along `axis`, with optional leading batch dimensions.
 *
 * The result shape is a.shape[:axis] + indices.shape[batch_dims:] + a.shape[axis + 1:].
 *
 * \param a The source tensor.
 * \param indices Either a tensor of indices, or a single scalar expression (which
 *        contributes no dimensions to the output and requires batch_dims == 0).
 * \param batch_dims Number of leading dimensions shared by `a` and `indices`.
 *        Negative values are counted from the rank of `indices`. Must be <= axis.
 * \param axis The axis of `a` to gather along. Negative values count from the end.
 * \param mode Out-of-range index handling: "clip" clamps into [0, a.shape[axis] - 1];
 *        "fast" uses the index as-is (caller must guarantee it is in range);
 *        any other value is treated as "wrap" (index taken modulo a.shape[axis],
 *        with negative indices wrapping from the end).
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor computing the gather.
 */
inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch_dims, int axis,
                   std::string mode = "clip", std::string name = "T_take",
                   std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(a->shape.size());
  }
  ICHECK_GE(axis, 0) << "axis out of bounds";
  ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
  auto axis_dim = a->shape[axis];
  // A scalar (PrimExpr) index has an empty shape.
  auto indices_shape = [&]() -> Array<PrimExpr> {
    if (auto tensor = indices.as<TensorNode>()) {
      return tensor->shape;
    } else {
      return {};
    }
  }();

  int indices_len = static_cast<int>(indices_shape.size());

  int batch_dims_ = batch_dims;
  if (batch_dims_ != 0) {
    ICHECK_GE(batch_dims_, -indices_len) << "batch_dims out of bounds";
    ICHECK_LE(batch_dims_, indices_len) << "batch_dims out of bounds";

    if (batch_dims_ < 0) {
      batch_dims_ = indices_len + batch_dims_;
    }

    ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
    ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
    for (int i = 0; i < batch_dims_; ++i) {
      // Symbolic extents cannot be compared statically, so only verify the
      // batch dimensions whose extents are both integer constants.
      const auto* a_dim = a->shape[i].as<IntImmNode>();
      const auto* indices_dim = indices_shape[i].as<IntImmNode>();
      if (a_dim != nullptr && indices_dim != nullptr) {
        ICHECK_EQ(a_dim->value, indices_dim->value)
            << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
      }
    }
  }

  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
  // a.shape[axis + 1:]. The batch dimensions are included in a.shape[:axis]
  // because batch_dims <= axis.
  Array<PrimExpr> out_shape;
  for (int i = 0; i < axis; ++i) {
    out_shape.push_back(a->shape[i]);
  }
  for (int i = batch_dims_; i < indices_len; ++i) {
    out_shape.push_back(indices_shape[i]);
  }
  for (size_t i = axis + 1; i < a->shape.size(); ++i) {
    out_shape.push_back(a->shape[i]);
  }

  // Reads the raw index value at `indices_position` (empty for a scalar index).
  auto get_index = [&](const Array<PrimExpr>& indices_position) -> PrimExpr {
    if (auto tensor = indices.as<Tensor>()) {
      return tensor.value()(indices_position);
    } else if (auto prim = indices.as<PrimExpr>()) {
      ICHECK_EQ(indices_position.size(), 0);
      return prim.value();
    } else {
      LOG(FATAL) << "Variant did not contain either allowed type";
    }
  };

  // Maps a raw index onto a position along `axis` according to `mode`.
  auto normalize_index = [&](const PrimExpr& idx) -> PrimExpr {
    if (mode == "clip") {
      return tvm::min(tvm::max(0, idx), axis_dim - 1);
    }
    if (mode == "fast") {
      // Used as-is: the caller is responsible for in-range indices.
      return idx;
    }
    // Any other mode is treated as "wrap". The inner truncmod may be negative
    // for negative indices; adding axis_dim and reducing again lands in
    // [0, axis_dim).
    return truncmod(truncmod(idx, axis_dim) + axis_dim, axis_dim);
  };

  // Output dimensions [axis, index_end) enumerate indices.shape[batch_dims:].
  // This mapping is shared by every mode, so batch_dims is honored uniformly.
  const int index_end = axis + indices_len - batch_dims_;

  return compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        // Position into `indices`: the shared batch dimensions, followed by the
        // output dimensions that enumerate indices.shape[batch_dims:].
        Array<PrimExpr> indices_position;
        for (int j = 0; j < batch_dims_; ++j) {
          indices_position.push_back(out_index[j]);
        }
        for (int j = axis; j < index_end; ++j) {
          indices_position.push_back(out_index[j]);
        }
        // Position into `a`: leading dimensions (batch dims included), the
        // gathered index along `axis`, then the dimensions after `axis`.
        Array<PrimExpr> real_indices;
        for (int j = 0; j < axis; ++j) {
          real_indices.push_back(out_index[j]);
        }
        real_indices.push_back(normalize_index(get_index(indices_position)));
        for (size_t j = static_cast<size_t>(index_end); j < out_index.size(); ++j) {
          real_indices.push_back(out_index[j]);
        }
        return a(real_indices);
      },
      name, tag);
}