in rela/r2d2_actor.h [221:249]
// Runs one acting step.
//
// Feeds `obs` together with the current recurrent state `hidden_` to the
// model's "act" method, replaces `hidden_` with the second element of the
// returned tuple, and returns the first element as the action. Both outputs
// are converted with torch::kCPU as the target device.
//
// When `replayBuffer_` is set, the hidden state used for this step is appended
// to `historyHidden_` before the call, and (obs, action) is pushed into
// `multiStepBuffer_` afterwards. `numAct_` is advanced by `batchsize_`.
//
// @param obs  observation passed to the model; moved to modelLocker_->device.
// @return     the action TensorDict produced by the model.
//
// Precondition: `hidden_` is non-empty (checked by assert only, so it is not
// enforced when asserts are compiled out).
virtual TensorDict act(TensorDict& obs) override {
  // Inference only: no autograd graph is recorded inside this scope.
  torch::NoGradGuard ng;
  assert(!hidden_.empty());

  // Remember the hidden state that this step starts from.
  if (replayBuffer_ != nullptr) {
    historyHidden_.push_back(hidden_);
  }

  TorchJitInput input;
  auto jitObs = utils::tensorDictToTorchDict(obs, modelLocker_->device);
  auto jitHid = utils::tensorDictToTorchDict(hidden_, modelLocker_->device);
  input.push_back(jitObs);
  input.push_back(jitHid);

  int id = -1;
  auto model = modelLocker_->getModel(&id);

  // Scoped guard that hands the model id back to the locker on every exit
  // path. The TorchScript call below can throw; without the guard the id
  // obtained from getModel() would never be released in that case.
  using ModelLockerRef = decltype(*modelLocker_);
  struct ModelLease {
    ModelLease(ModelLockerRef l, int i) : locker(l), id(i) {}
    ModelLease(const ModelLease&) = delete;
    ModelLease& operator=(const ModelLease&) = delete;
    ~ModelLease() { locker.releaseModel(id); }
    ModelLockerRef locker;
    int id;
  };

  // The lease lives only for the duration of the model call, so on the normal
  // path the release still happens right after the output elements are copied.
  auto output = [&] {
    ModelLease lease(*modelLocker_, id);
    return model.get_method("act")(input).toTuple()->elements();
  }();

  // output[0] is the action, output[1] the next hidden state.
  // NOTE(review): the meaning of the trailing `true` flag is defined by
  // utils::iValueToTensorDict — confirm there.
  auto action = utils::iValueToTensorDict(output[0], torch::kCPU, true);
  hidden_ = utils::iValueToTensorDict(output[1], torch::kCPU, true);

  if (replayBuffer_ != nullptr) {
    multiStepBuffer_.pushObsAndAction(obs, action);
  }

  numAct_ += batchsize_;
  return action;
}