EnvStepperFuture EnvStepper::step()

in src/env.cc [273:349]


EnvStepperFuture EnvStepper::step(int bufferIndex, py::object actionObject) {
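  // Validate the action tensor and claim the requested buffer up front, then
  // hand the actual dispatch to the async pool and return a future immediately.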
  auto actionOpt = rpc::tryFromPython(actionObject);
  if (!actionOpt) {
    throw std::runtime_error(
        "EnvStepper::step function was passed an action argument that could not be converted to a Tensor");
  }
  auto& action = *actionOpt;

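  // The action data is read later through data<long>(), so require an integer
  // tensor (kInt32 or kInt64) whose element size matches the platform's long.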
  if (action.itemsize() != sizeof(long) ||
      (action.scalar_type() != rpc::Tensor::kInt32 && action.scalar_type() != rpc::Tensor::kInt64)) {
    throw std::runtime_error("EnvStepper::step expected action tensor with data type long");
  }
  if (action.dim() != 1) {
    throw std::runtime_error("EnvStepper::step expected a 1-dimensional tensor");
  }

  if (bufferIndex < 0 || (size_t)bufferIndex >= bufferBusy.size()) {
    throw std::runtime_error(fmt::sprintf("EnvStepper: buffer index (%d) out of range", bufferIndex));
  }

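  // Claim the buffer: exchange() returns the previous value, so observing true
  // here means another step() on this buffer index is still in flight.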
  if (bufferBusy[bufferIndex].exchange(true, std::memory_order_relaxed)) {
    throw std::runtime_error(
        fmt::sprintf("EnvStepper: attempt to step buffer index %d twice concurrently", bufferIndex));
  }

  auto& buffer = shared->buffers[bufferIndex];

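  // If the action lives on the GPU, capture the caller's current CUDA stream so
  // the device-to-host copy in the async task is ordered after work already
  // queued on that stream.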
  std::optional<rpc::CUDAStream> stream;
  if (action.is_cuda()) {
    stream.emplace(rpc::getCurrentCUDAStream());
  }

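  // Split the batch of `size` environments across the clients in contiguous
  // chunks of `stride` environments each (ceiling division).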
  size_t size = action.size(0);
  size_t strideDivisor = numClients_;
  size_t stride = (size + strideDivisor - 1) / strideDivisor;

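  // The remaining work (staging the action on the host and waking the env
  // workers) happens asynchronously, so this call returns its EnvStepperFuture
  // without blocking.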
  async.run([this, action = std::move(action), bufferIndex, size, stride, stream = std::move(stream)]() mutable {
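    // A CUDA action is copied into a pinned host buffer that is cached per
    // buffer index and allocated on first use; the stream is then synchronized
    // so the CPU-side reads below see the copied data.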
    if (stream) {
      rpc::CUDAStreamGuard sg(*stream);
      auto& pinned = actionPinned[bufferIndex];
      if (!pinned.defined()) {
        pinned = action.cpu().pin_memory();
      }
      pinned.copy_(action, true);
      action = pinned;
      stream->synchronize();
    }

    auto& buffer = shared->buffers[bufferIndex];

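    // For each client: record where its slice of results starts, add its share
    // of pending steps, then wake it by pushing the buffer index onto its queue
    // and posting its semaphore.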
    size_t clientIndex = 0;
    for (size_t i = 0; i < size; i += stride, ++clientIndex) {
      size_t nSteps = std::min(size - i, stride);
      auto& input = buffer.clientInputs[clientIndex];
      input.resultOffset.store(i);
      input.nStepsIn.fetch_add(nSteps);

      auto& in = shared->clientIn[clientIndex];
      if (in.queue.size() >= in.queue.capacity()) {
        fatal("EnvStepper: shared queue is full");
      }
      in.queue.push(bufferIndex);
      in.semaphore.post();
    }

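    // Record that at least one step has been issued for this buffer.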
    bufferStarted[bufferIndex] = true;

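    // Publish each environment's action through its per-env atomic. The value
    // is advanced by 1 + action, presumably so a polling env worker can tell
    // that a new action has arrived even when the action value itself repeats.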
    auto acc = action.data<long>();

    for (size_t i = 0; i != size; ++i) {
      auto& envAction = buffer.envInputs[i].action;
      envAction.store(envAction.load(std::memory_order_relaxed) + 1 + acc[i], std::memory_order_relaxed);
    }
  });

  return {this, bufferIndex, size, stride, &buffer};
}
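
The per-client partitioning above is plain ceiling division over the batch:
each client receives a contiguous chunk of stride environments, and the last
chunk may be shorter. Below is a minimal standalone sketch of that arithmetic
(not part of src/env.cc; the batch size and client count are made up for
illustration):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t size = 10;       // hypothetical batch size
  std::size_t numClients = 4;  // hypothetical client count
  std::size_t stride = (size + numClients - 1) / numClients;  // ceil(10 / 4) == 3

  std::size_t clientIndex = 0;
  for (std::size_t i = 0; i < size; i += stride, ++clientIndex) {
    std::size_t nSteps = std::min(size - i, stride);
    std::printf("client %zu steps envs [%zu, %zu)\n", clientIndex, i, i + nSteps);
  }
  // client 0 steps envs [0, 3)
  // client 1 steps envs [3, 6)
  // client 2 steps envs [6, 9)
  // client 3 steps envs [9, 10)
}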