in src/moolib.cc [671:742]
// Walks `dest` and `source` in parallel and copies every tensor leaf of `source` into the
// corresponding tensor leaf of `dest`.
//
// Structure matching (driven by `dest`):
//  - dict:  every key of `dest` must exist in `source`; extra keys in `source` are ignored.
//  - list / tuple: `source` must have at least as many elements as `dest`; elements are paired
//    by position and any extra trailing elements of `source` are ignored.
//  - leaves that rpc::tryFromPython accepts (tensors, presumably): `source` must also be accepted.
//  - any other `dest` value is left untouched.
//
// Copy modes (selected by the template parameter `cat`):
//  - cat == true: the source tensor is consumed along `batchDimension` starting at
//    `catBatchInputOffset`. At most `batchSize - catBatchOutputOffset` entries are written into
//    `dest` starting at `catBatchOutputOffset`. All tensors visited in one pass (counted by
//    `currentTensor`) must share the same size in the batch dimension; the first one records it in
//    `catBatchInputSize`.
//  - cat == false: the whole source tensor is copied into index `nextStackIndex` of `dest` along
//    `batchDimension` (stacking).
//
// Throws std::runtime_error on structure/type mismatches, on tensors with too few dimensions, and on
// inconsistent batch sizes during cat. Internal invariant violations are reported through fatal().
void visit(const py::handle& dest, const py::handle& source) {
  if (py::isinstance<py::dict>(dest)) {
    if (!py::isinstance<py::dict>(source)) {
      throw std::runtime_error("type mismatch in batch operation");
    }
    auto sourceDict = py::reinterpret_borrow<py::dict>(source);
    auto destDict = py::reinterpret_borrow<py::dict>(dest);
    for (auto& [key, value] : destDict) {
      // Without this check a missing key would surface as an opaque Python KeyError.
      if (!sourceDict.contains(key)) {
        throw std::runtime_error("key mismatch in batch operation");
      }
      visit<cat>(value, sourceDict[key]);
    }
  } else if (py::isinstance<py::list>(dest)) {
    if (!py::isinstance<py::list>(source)) {
      throw std::runtime_error("type mismatch in batch operation");
    }
    auto sourceList = py::reinterpret_borrow<py::list>(source);
    auto destList = py::reinterpret_borrow<py::list>(dest);
    size_t n = destList.size();
    // Without this check a shorter source would surface as an opaque Python IndexError.
    if (sourceList.size() < n) {
      throw std::runtime_error("size mismatch in batch operation");
    }
    for (size_t i = 0; i != n; ++i) {
      visit<cat>(destList[i], sourceList[i]);
    }
  } else if (auto destT = rpc::tryFromPython(dest)) {
    auto sourceT = rpc::tryFromPython(source);
    if (!sourceT) {
      throw std::runtime_error("type mismatch in batch operation");
    }
    auto s = sourceT->sizes();
    // cat indexes dimension `batchDimension` of the source, so it needs more than batchDimension
    // dimensions. stack inserts a new dimension there, so batchDimension dimensions suffice.
    if (static_cast<int64_t>(s.size()) <= (cat ? batchDimension : batchDimension - 1)) {
      throw std::runtime_error(fmt::sprintf(
          "Given input tensor with %d dimensions, cannot %s in dimension %d", s.size(), cat ? "cat" : "stack",
          batchDimension));
    }
    if (cat) {
      int64_t inputOffset = catBatchInputOffset;
      int64_t n = s[batchDimension];
      // Validate user-supplied shapes before checking the internal offset invariant, so that
      // inconsistent input raises a catchable exception instead of reaching fatal().
      if (currentTensor == 0) {
        catBatchInputSize = n;
      } else if (n != catBatchInputSize) {
        throw std::runtime_error(fmt::sprintf(
            "Batch dimension size mismatch; during a cat operation, all tensors must have the same size in the "
            "batch dimension (%d). Got %d and %d",
            batchDimension, catBatchInputSize, n));
      }
      if (inputOffset > n) {
        fatal("Batch internal error: offset > n");
      }
      int64_t outputOffset = catBatchOutputOffset;
      int64_t left = batchSize - outputOffset;
      n -= inputOffset;
      if (n <= left && inputOffset == 0) {
        // The whole source fits in the remaining output space: copy it without narrowing.
        destT->narrow(batchDimension, outputOffset, n).copy_(*sourceT);
      } else {
        // Copy only what fits; the remainder is expected to be consumed by a later pass.
        n = std::min(n, left);
        destT->narrow(batchDimension, outputOffset, n).copy_(sourceT->narrow(batchDimension, inputOffset, n));
      }
    } else {
      destT->select(batchDimension, nextStackIndex).copy_(*sourceT);
    }
    ++currentTensor;
  } else if (py::isinstance<py::tuple>(dest)) {
    if (!py::isinstance<py::tuple>(source)) {
      throw std::runtime_error("type mismatch in batch operation");
    }
    auto sourceTuple = py::reinterpret_borrow<py::tuple>(source);
    auto destTuple = py::reinterpret_borrow<py::tuple>(dest);
    size_t n = destTuple.size();
    // Without this check a shorter source would surface as an opaque Python IndexError.
    if (sourceTuple.size() < n) {
      throw std::runtime_error("size mismatch in batch operation");
    }
    for (size_t i = 0; i != n; ++i) {
      visit<cat>(destTuple[i], sourceTuple[i]);
    }
  }
}