in src/env.h [254:328]
/// Performs one environment step for the slot (bufferIndex, batchIndex).
///
/// Spins until the action value stored in
/// `shared->buffers[bufferIndex].envInputs[batchIndex].action` differs from the
/// value seen on the previous call for this slot, then (with the Python GIL
/// held) either resets the environment (first call) or steps it with the
/// decoded action, and finally writes "done", "reward" and every observation
/// entry into the buffer's batch via fillBatch().
///
/// @param shared       Shared state holding the per-buffer inputs and batches.
/// @param bufferIndex  Index into `shared->buffers` (and into `prevActions_`).
/// @param batchIndex   Index of this environment's slot within the buffer.
/// @throws std::runtime_error if no new action shows up within the timeout
///         window (see kActionTimeout below), or if an observation array
///         cannot be converted to a C-contiguous layout.
///
/// Returns silently, without stepping, if `terminate_` is observed while
/// waiting for an action.
void step(Shared* shared, size_t bufferIndex, size_t batchIndex) {
  // The clock is only consulted once per this many spin iterations so that the
  // wait loop stays cheap.
  constexpr uint32_t kTimeCheckInterval = 0x100000;
  // Measured from the first clock check, not from the start of the wait.
  constexpr std::chrono::seconds kActionTimeout(120);

  ++steps;

  // Grow the per-slot bookkeeping on demand.
  if (prevActions_.size() <= bufferIndex) {
    prevActions_.resize(bufferIndex + 1);
  }
  auto& prevActions = prevActions_[bufferIndex];
  if (prevActions.size() <= batchIndex) {
    prevActions.resize(batchIndex + 1);
  }

  const uint32_t prevAction = prevActions[batchIndex];
  uint32_t action = prevAction;
  auto& sa = shared->buffers[bufferIndex].envInputs[batchIndex].action;

  uint32_t timeCheckCounter = kTimeCheckInterval;
  std::optional<std::chrono::steady_clock::time_point> waitTime;

  // Busy-wait until the slot holds a value different from the last one we saw.
  do {
    action = sa.load();
    if (terminate_) {
      return;
    }
    if (--timeCheckCounter == 0) {
      timeCheckCounter = kTimeCheckInterval;
      if (!waitTime) {
        waitTime = std::chrono::steady_clock::now();
      } else if (std::chrono::steady_clock::now() - *waitTime >= kActionTimeout) {
        throw std::runtime_error("Timed out waiting for env action");
      }
    }
  } while (action == prevAction);

  prevActions[batchIndex] = action;
  // Recover the real action as (new - prev - 1) in unsigned arithmetic.
  // NOTE(review): presumably the producer stores prev + 1 + action so that the
  // slot value always changes; confirm against the writer side.
  action -= prevAction + 1;

  {
    // Everything below touches Python objects and therefore needs the GIL.
    py::gil_scoped_acquire gil;

    bool done = false;
    float reward = 0.0f;
    py::dict obs;
    py::object rawObs;

    if (steps == 1) {
      // Very first call: there is no action to apply yet, just reset.
      rawObs = reset_();
    } else {
      py::tuple tup = step_(action);
      rawObs = tup[0];
      reward = static_cast<float>(py::float_(tup[1]));
      done = static_cast<bool>(py::bool_(tup[2]));
      if (done) {
        // log.debug("episode done after %d steps with %g total reward\n", episodeStep, episodeReturn);
        // The observation reported for a finished episode is the first
        // observation of the next one.
        rawObs = reset_();
      }
    }

    // Normalise the observation to a dict; a bare array is stored as "state".
    if (py::isinstance<py::dict>(rawObs)) {
      obs = rawObs.cast<py::dict>();
    } else {
      obs["state"] = rawObs.cast<py::array>();
    }

    auto& buffer = shared->buffers[bufferIndex];
    auto& batch = buffer.batchData;

    // One-time batch allocation: the first caller to flip batchAllocating
    // performs the allocation, everybody else spins until it is published.
    // NOTE(review): if allocateBatch throws, batchAllocated is never set and
    // the waiters below would spin forever.
    if (!buffer.batchAllocated.load()) {
      if (buffer.batchAllocating.exchange(true)) {
        while (!buffer.batchAllocated) {
        }
      } else {
        allocateBatch(shared, batch, obs);
        buffer.batchAllocated = true;
      }
    }

    fillBatch(shared, batch, batchIndex, "done", &done, sizeof(bool));
    fillBatch(shared, batch, batchIndex, "reward", &reward, sizeof(float));

    for (auto& [key, value] : obs) {
      // `stro` is unused here but is kept as a named binding.
      // NOTE(review): it presumably owns the storage `str` points into.
      auto [str, stro] = rpc::pyStrView(key);
      py::array arr(py::reinterpret_borrow<py::object>(value));
      // data()/nbytes() only describe the logical contents when the array is
      // C-contiguous; strided views (e.g. transposed observations) must be
      // compacted first or the raw byte copy would scramble the elements.
      if ((arr.flags() & py::array::c_style) == 0) {
        arr = py::array::ensure(arr, py::array::c_style);
        if (!arr) {
          throw std::runtime_error("Failed to make observation array contiguous");
        }
      }
      fillBatch(
          shared, batch, batchIndex, str, const_cast<float*>(static_cast<const float*>(arr.data())), arr.nbytes());
    }
  }
}