void visit()

in src/moolib.cc [671:742]


  // Walks `dest` and `source` in lockstep and copies every tensor found in
  // `source` into the tensor at the same position in `dest`.
  //
  // Dispatch is decided by the kind of `dest`, checked in this order:
  //   * dict   - recurse for each key of `dest`, looking the key up in `source`.
  //   * list   - recurse element by element over the length of `dest`.
  //   * tensor - (whatever rpc::tryFromPython accepts) copy the data, see below.
  //   * tuple  - recurse element by element over the length of `dest`.
  // Any other kind of `dest` is left untouched.
  //
  // Tensor copy:
  //   * cat mode: copies the part of the source that starts at
  //     catBatchInputOffset along batchDimension into `dest` starting at
  //     catBatchOutputOffset, clamped to the room left before batchSize.
  //     The first tensor visited (currentTensor == 0) records its size along
  //     batchDimension in catBatchInputSize; later tensors must match it.
  //   * stack mode: copies the whole source into index nextStackIndex of `dest`
  //     along batchDimension.
  //   currentTensor is incremented after each tensor copy.
  //
  // NOTE(review): `cat` is not declared in the visible lines; it is presumably
  // the bool template parameter of this function - confirm against the full file.
  //
  // Throws std::runtime_error when `source` is not the same kind as `dest`, when
  // the source tensor has too few dimensions, or when cat inputs disagree on
  // their size in the batch dimension.
  void visit(const py::handle& dest, const py::handle& source) {
    // All kind mismatches are reported with the same error.
    const auto rejectIf = [](bool mismatch) {
      if (mismatch) {
        throw std::runtime_error("type mismatch in batch operation");
      }
    };

    if (py::isinstance<py::dict>(dest)) {
      rejectIf(!py::isinstance<py::dict>(source));
      const py::dict sourceItems = py::reinterpret_borrow<py::dict>(source);
      const py::dict destItems = py::reinterpret_borrow<py::dict>(dest);
      // The keys of dest drive the traversal; source is indexed by the same key.
      for (const auto& entry : destItems) {
        visit<cat>(entry.second, sourceItems[entry.first]);
      }
      return;
    }

    if (py::isinstance<py::list>(dest)) {
      rejectIf(!py::isinstance<py::list>(source));
      const py::list sourceItems = py::reinterpret_borrow<py::list>(source);
      const py::list destItems = py::reinterpret_borrow<py::list>(dest);
      const size_t count = destItems.size();
      for (size_t index = 0; index != count; ++index) {
        visit<cat>(destItems[index], sourceItems[index]);
      }
      return;
    }

    // Tensors are tested before tuples, exactly as in the dispatch order above.
    if (auto destT = rpc::tryFromPython(dest)) {
      auto sourceT = rpc::tryFromPython(source);
      rejectIf(!sourceT);

      auto sourceSizes = sourceT->sizes();
      // cat indexes sourceSizes[batchDimension], so it needs one more dimension
      // than stack does.
      const bool tooFewDimensions =
          static_cast<int64_t>(sourceSizes.size()) <= (cat ? batchDimension : batchDimension - 1);
      if (tooFewDimensions) {
        throw std::runtime_error(fmt::sprintf(
            "Given input tensor with %d dimensions, cannot %s in dimension %d", sourceSizes.size(),
            cat ? "cat" : "stack", batchDimension));
      }

      if (!cat) {
        destT->select(batchDimension, nextStackIndex).copy_(*sourceT);
      } else {
        const int64_t inputOffset = catBatchInputOffset;
        const int64_t sourceLength = sourceSizes[batchDimension];
        if (inputOffset > sourceLength) {
          fatal("Batch internal error: offset > n");
        }

        // The first tensor fixes the expected length; the rest must agree.
        if (currentTensor == 0) {
          catBatchInputSize = sourceLength;
        } else if (sourceLength != catBatchInputSize) {
          throw std::runtime_error(fmt::sprintf(
              "Batch dimension size mismatch; during a cat operation, all tensors must have the same size in the "
              "batch dimension (%d). Got %d and %d",
              batchDimension, catBatchInputSize, sourceLength));
        }

        const int64_t outputOffset = catBatchOutputOffset;
        const int64_t room = batchSize - outputOffset;
        const int64_t pending = sourceLength - inputOffset;

        const bool wholeSourceFits = inputOffset == 0 && pending <= room;
        if (wholeSourceFits) {
          destT->narrow(batchDimension, outputOffset, pending).copy_(*sourceT);
        } else {
          // Copy only as much of the remaining source as the output can hold.
          const int64_t count = std::min(pending, room);
          destT->narrow(batchDimension, outputOffset, count)
              .copy_(sourceT->narrow(batchDimension, inputOffset, count));
        }
      }

      ++currentTensor;
      return;
    }

    if (py::isinstance<py::tuple>(dest)) {
      rejectIf(!py::isinstance<py::tuple>(source));
      const py::tuple sourceItems = py::reinterpret_borrow<py::tuple>(source);
      const py::tuple destItems = py::reinterpret_borrow<py::tuple>(dest);
      const size_t count = destItems.size();
      for (size_t index = 0; index != count; ++index) {
        visit<cat>(destItems[index], sourceItems[index]);
      }
    }
  }