in src/moolib.cc [596:668]
py::object prepareForBatchCopy(const py::handle& v) {
  // Builds the batched output structure for a single input value `v`.
  //
  // The `cat` template argument selects the batching mode:
  //   cat == true  -> inputs are concatenated along the existing dimension `batchDimension`;
  //   cat == false -> inputs are stacked into a new dimension inserted at `batchDimension`.
  //
  // * Values matched by py::isinstance<py::dict / py::list / py::tuple> are rebuilt as a fresh
  //   plain py::dict / py::list / py::tuple whose elements are processed recursively.
  // * A value accepted by rpc::tryFromPython gets a newly allocated output tensor (same scalar
  //   type, allocated on `device`) whose batch dimension has size `batchSize`.
  //   - stack: the input is copied into index 0 of the batch dimension.
  //   - cat: the input is copied into the leading rows of the batch dimension. At most
  //     `batchSize` rows are copied, starting at row `catBatchInputOffset` of the input.
  //   The remaining part of the output is not written here (presumably filled by later copies;
  //   confirm with callers). Every such tensor increments `nTensors`.
  // * Any other value is returned unchanged (a new reference to the same object).
  //
  // Side effects: overwrites the member `sizes`. In cat mode, when `nTensors` is 0, it also
  // records the tensor's batch-dimension size in `catBatchInputSize`.
  // Throws std::runtime_error in two cases:
  //   - a tensor has too few dimensions for the requested cat/stack;
  //   - in cat mode, a later tensor's batch-dimension size differs from `catBatchInputSize`.
  // Calls fatal() if `catBatchInputOffset` exceeds a tensor's batch-dimension size.

  if (py::isinstance<py::dict>(v)) {
    py::dict out;
    for (auto&& [key, value] : py::reinterpret_borrow<py::dict>(v)) {
      out[key] = prepareForBatchCopy<cat>(value);
    }
    // Explicit move: the local's type differs from py::object, so C++17 would otherwise copy it
    // (an extra refcount increment/decrement).
    return std::move(out);
  }

  if (py::isinstance<py::list>(v)) {
    auto src = py::reinterpret_borrow<py::list>(v);
    const size_t count = src.size();
    py::list out(count);
    for (size_t idx = 0; idx < count; ++idx) {
      out[idx] = prepareForBatchCopy<cat>(src[idx]);
    }
    return std::move(out);
  }

  if (auto input = rpc::tryFromPython(v)) {
    auto shape = input->sizes();

    // cat needs an existing dimension at batchDimension (rank > batchDimension).
    // stack inserts a new dimension there, so rank >= batchDimension suffices.
    const auto rankFloor = cat ? batchDimension : batchDimension - 1;
    if (static_cast<int64_t>(shape.size()) <= rankFloor) {
      throw std::runtime_error(fmt::sprintf(
          "Given input tensor with %d dimensions, cannot %s in dimension %d", shape.size(), cat ? "cat" : "stack",
          batchDimension));
    }

    // Output shape: the input shape with the batch dimension either resized (cat) or inserted
    // (stack), and set to batchSize in both cases.
    if constexpr (cat) {
      sizes.assign(shape.begin(), shape.end());
    } else {
      sizes.resize(shape.size() + 1);
      std::copy(shape.begin(), shape.begin() + batchDimension, sizes.begin());
      std::copy(shape.begin() + batchDimension, shape.end(), sizes.begin() + batchDimension + 1);
    }
    sizes[batchDimension] = batchSize;

    rpc::Tensor batched = rpc::empty(rpc::IntArrayRef(sizes.data(), sizes.size()), input->scalar_type(), device);

    if constexpr (cat) {
      int64_t offset = catBatchInputOffset;
      int64_t inputRows = shape[batchDimension];
      if (offset > inputRows) {
        fatal("Batch internal error: offset > n");
      }
      if (nTensors == 0) {
        // The first tensor fixes the batch-dimension size every later tensor must match.
        catBatchInputSize = inputRows;
      } else if (inputRows != catBatchInputSize) {
        throw std::runtime_error(fmt::sprintf(
            "Batch dimension size mismatch; during a cat operation, all tensors must have the same size in the "
            "batch dimension (%d). Got %d and %d",
            batchDimension, catBatchInputSize, inputRows));
      }
      const int64_t remaining = inputRows - offset;
      if (offset == 0 && remaining <= batchSize) {
        // The whole input fits: copy it without narrowing the source.
        batched.narrow(batchDimension, 0, remaining).copy_(*input);
      } else {
        const int64_t rows = std::min(remaining, batchSize);
        batched.narrow(batchDimension, 0, rows).copy_(input->narrow(batchDimension, offset, rows));
      }
    } else {
      batched.select(batchDimension, 0).copy_(*input);
    }

    ++nTensors;
    return rpc::toPython(batched);
  }

  if (py::isinstance<py::tuple>(v)) {
    auto src = py::reinterpret_borrow<py::tuple>(v);
    const size_t count = src.size();
    py::tuple out(count);
    for (size_t idx = 0; idx < count; ++idx) {
      out[idx] = prepareForBatchCopy<cat>(src[idx]);
    }
    return std::move(out);
  }

  return py::reinterpret_borrow<py::object>(v);
}