in src/env.cc [273:349]
/// Submits one batch of actions for buffer `bufferIndex` and schedules the step.
///
/// The call validates the action tensor and the buffer index, marks the buffer
/// as busy, and hands the dispatch work to `async`. The scheduled task:
///   1. for a CUDA tensor, copies the actions into a per-buffer pinned tensor
///      on the stream that was current when `step` was called and waits for
///      that stream;
///   2. splits the batch into consecutive chunks of `stride` entries, one chunk
///      per client index, records the chunk offset/length in
///      `buffer.clientInputs`, pushes the buffer index onto that client's
///      queue and posts its semaphore;
///   3. writes the actions into `buffer.envInputs`.
///
/// @param bufferIndex  Index into `bufferBusy` / `shared->buffers`.
/// @param actionObject Python object convertible to a 1-dimensional int32/int64
///                     tensor whose item size equals `sizeof(long)`.
/// @return A future constructed from this stepper, the buffer index, the batch
///         size, the per-client stride and a pointer to the buffer.
/// @throws std::runtime_error if the argument cannot be converted to a tensor,
///         has the wrong data type or rank, if `bufferIndex` is out of range,
///         if there are no clients to split the batch across, or if the buffer
///         is already being stepped.
EnvStepperFuture EnvStepper::step(int bufferIndex, py::object actionObject) {
  auto actionOpt = rpc::tryFromPython(actionObject);
  if (!actionOpt) {
    throw std::runtime_error(
        "EnvStepper::step function was passed an action argument that could not be converted to a Tensor");
  }
  auto& action = *actionOpt;
  // The task below reads the data through a `long*`, so the element width must
  // match `long` in addition to the tensor being an integer type.
  if (action.itemsize() != sizeof(long) ||
      (action.scalar_type() != rpc::Tensor::kInt32 && action.scalar_type() != rpc::Tensor::kInt64)) {
    throw std::runtime_error("EnvStepper::step expected action tensor with data type long");
  }
  if (action.dim() != 1) {
    throw std::runtime_error("EnvStepper::step expected a 1-dimensional tensor");
  }
  if (bufferIndex < 0 || static_cast<size_t>(bufferIndex) >= bufferBusy.size()) {
    throw std::runtime_error(fmt::sprintf("EnvStepper: buffer index (%d) out of range", bufferIndex));
  }
  size_t strideDivisor = numClients_;
  if (strideDivisor == 0) {
    // Checked before the buffer is claimed: dividing by zero below would be
    // undefined behaviour instead of a catchable error.
    throw std::runtime_error("EnvStepper::step: no clients to distribute the batch across");
  }
  if (bufferBusy[bufferIndex].exchange(true, std::memory_order_relaxed)) {
    throw std::runtime_error(
        fmt::sprintf("EnvStepper: attempt to step buffer index %d twice concurrently", bufferIndex));
  }

  auto& buffer = shared->buffers[bufferIndex];
  size_t size = action.size(0);
  // Ceiling division: each client receives at most `stride` consecutive entries.
  size_t stride = (size + strideDivisor - 1) / strideDivisor;

  try {
    std::optional<rpc::CUDAStream> stream;
    if (action.is_cuda()) {
      // Remember the caller's stream so the device-to-host copy is issued on it.
      stream.emplace(rpc::getCurrentCUDAStream());
    }
    async.run([this, action = std::move(action), bufferIndex, size, stride, stream = std::move(stream)]() mutable {
      if (stream) {
        rpc::CUDAStreamGuard sg(*stream);
        // The pinned tensor is created on first use and reused for later steps
        // of the same buffer index.
        auto& pinned = actionPinned[bufferIndex];
        if (!pinned.defined()) {
          pinned = action.cpu().pin_memory();
        }
        pinned.copy_(action, true);
        action = pinned;
        // The copy above was requested as non-blocking; wait before reading it.
        stream->synchronize();
      }
      auto& buffer = shared->buffers[bufferIndex];
      size_t clientIndex = 0;
      for (size_t i = 0; i < size; i += stride, ++clientIndex) {
        const int nSteps = static_cast<int>(std::min(size - i, stride));
        auto& input = buffer.clientInputs[clientIndex];
        input.resultOffset.store(i);
        input.nStepsIn.fetch_add(nSteps);
        auto& in = shared->clientIn[clientIndex];
        if (in.queue.size() >= in.queue.capacity()) {
          fatal("EnvStepper: shared queue is full");
        }
        in.queue.push(bufferIndex);
        in.semaphore.post();
      }
      bufferStarted[bufferIndex] = true;
      // NOTE(review): clients are signalled above before the actions are
      // written here; presumably the consumer waits for its action slot to
      // change - confirm against the client side before reordering.
      auto acc = action.data<long>();
      for (size_t i = 0; i != size; ++i) {
        auto& slot = buffer.envInputs[i].action;
        // Stored value is "previous + 1 + new action", so the slot changes on
        // every step whenever the action is non-negative.
        slot.store(slot.load(std::memory_order_relaxed) + 1 + acc[i], std::memory_order_relaxed);
      }
    });
  } catch (...) {
    // Scheduling failed, so no task will ever process this buffer. Release the
    // busy flag; otherwise every later step() on this index would be rejected
    // as a concurrent step.
    bufferBusy[bufferIndex].store(false, std::memory_order_relaxed);
    throw;
  }
  return {this, bufferIndex, size, stride, &buffer};
}