in rela/r2d2_actor.h [103:172]
virtual void postAct(const torch::Tensor& r, const torch::Tensor& t) {
  // not recording: this actor has no replay buffer attached
  if (replayBuffer_ == nullptr) {
    return;
  }

  multiStepBuffer_->pushRewardAndTerminal(r, t);
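  // multiStepBuffer_ pairs these rewards/terminals with the observations
  // and actions recorded during act(); once full, it can pop a transition
  // whose reward field is the accumulated multi-step return.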
  // if the ith env is terminal, reset its hidden states
  // h0: [num_layers * num_directions, batch, hidden_size]
  TensorDict h0 = getH0(1, numPlayer_);
  auto terminal = t.accessor<bool, 1>();
  for (int i = 0; i < terminal.size(0); i++) {
    if (!terminal[i]) {
      continue;
    }
    for (auto& kv : hidden_) {
      // hidden_ is [numLayer, numEnvs * numPlayer, hidDim]; env i's players
      // occupy the slice [i * numPlayer_, (i + 1) * numPlayer_) along dim 1
      kv.second.narrow(1, i * numPlayer_, numPlayer_) = h0.at(kv.first);
    }
  }
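  // the reset writes through the narrow() view, so the next act() call
  // starts the new episode for env i from the initial hidden state.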
  assert(multiStepBuffer_->size() == historyHidden_.size());
  if (!multiStepBuffer_->canPop()) {
    assert(!r2d2Buffer_->canPop());
    return;
  }
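  // canPop() stays false for the first few frames of a run while the
  // multi-step window fills up, so nothing reaches the replay buffer yet.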
  {
    FFTransition transition = multiStepBuffer_->popTransition();

    // hidden state at the popped step and at the current step
    TensorDict hid = historyHidden_.front();
    TensorDict nextHid = historyHidden_.back();
    historyHidden_.pop_front();

    // assemble the priority-model input: transition fields plus the two
    // hidden states, transposed to batch-first so the batched runner can
    // stack inputs from many actors along dim 0
    auto input = transition.toDict();
    for (auto& kv : hid) {
      auto ret = input.emplace(kv.first, kv.second.transpose(0, 1));
      assert(ret.second);
    }
    for (auto& kv : nextHid) {
      auto ret = input.emplace("next_" + kv.first, kv.second.transpose(0, 1));
      assert(ret.second);
    }

    int slot = -1;
    auto futureReply = runner_->call("compute_priority", input, &slot);
    // get() blocks until the batched "compute_priority" call has run
    auto priority = futureReply->get(slot)["priority"];

    r2d2Buffer_->push(transition, priority, hid);
  }
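  // r2d2Buffer_ regroups these per-step transitions into fixed-length
  // sequences; storing hid with each push keeps the hidden state needed to
  // re-unroll the RNN from the start of a sequence (R2D2's stored-state
  // strategy).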
  if (!r2d2Buffer_->canPop()) {
    return;
  }

  std::vector<RNNTransition> batch;
  torch::Tensor seqBatchPriority;  // per-step priorities of the popped sequences
  torch::Tensor batchLen;          // valid (unpadded) length of each sequence
  std::tie(batch, seqBatchPriority, batchLen) = r2d2Buffer_->popTransition();
  // collapse per-step priorities into one priority per sequence
  auto priority = aggregatePriority(seqBatchPriority, batchLen, eta_);
  replayBuffer_->add(batch, priority);
}
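
aggregatePriority collapses the per-step priorities of each sequence into a single replay priority. rela defines it elsewhere; the sketch below only illustrates the mixture R2D2 (Kapturowski et al., 2019) uses, p = eta * max_i(p_i) + (1 - eta) * mean_i(p_i) over the valid steps, assuming seqBatchPriority is [seqLen, batchSize] and batchLen holds each sequence's unpadded length. The name aggregatePrioritySketch and the masking details are illustrative, not rela's implementation.

// A minimal sketch of R2D2-style priority aggregation, NOT rela's actual
// implementation; assumes per-step priorities are nonnegative (absolute
// TD errors), so zeroing the padded tail cannot affect the max.
#include <torch/torch.h>

torch::Tensor aggregatePrioritySketch(
    const torch::Tensor& seqBatchPriority,  // assumed [seqLen, batchSize]
    const torch::Tensor& batchLen,          // assumed [batchSize]
    float eta) {
  int64_t seqLen = seqBatchPriority.size(0);
  // mask[t][b] = 1 iff step t lies within sequence b's valid length
  auto mask = torch::arange(seqLen, batchLen.options())
                  .unsqueeze(1)
                  .lt(batchLen.unsqueeze(0))
                  .to(seqBatchPriority.scalar_type());
  auto masked = seqBatchPriority * mask;
  auto pMax = std::get<0>(masked.max(0));                        // [batchSize]
  auto pMean = masked.sum(0) / batchLen.to(masked.scalar_type()); // [batchSize]
  // p = eta * max + (1 - eta) * mean
  return eta * pMax + (1 - eta) * pMean;
}

Mixing in the max keeps a sequence with a single large TD error from being drowned out by averaging over an otherwise well-predicted sequence.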