virtual TensorDict act()

in rela/r2d2_actor.h [221:249]


  // Runs one batched inference step through the current model.
  //
  // Autograd is disabled for the whole call. The scripted model's "act"
  // method is invoked with (obs, hidden_), both converted for
  // modelLocker_->device. It must return a tuple: element 0 is converted into
  // the returned action dict, element 1 replaces hidden_ (both converted with
  // torch::kCPU as the target device).
  //
  // When a replay buffer is attached, the hidden state as it was *before*
  // this step is appended to historyHidden_, and (obs, action) is pushed into
  // multiStepBuffer_. numAct_ is advanced by batchsize_ on every successful
  // call.
  //
  // @param obs  observation dict for this step.
  // @return     action dict produced by the model.
  virtual TensorDict act(TensorDict& obs) override {
    torch::NoGradGuard ng;
    // hidden_ must have been initialized before the first act() call.
    assert(!hidden_.empty());

    if (replayBuffer_ != nullptr) {
      // Record the state the model is about to consume, not the one it returns.
      historyHidden_.push_back(hidden_);
    }

    TorchJitInput input;
    auto jitObs = utils::tensorDictToTorchDict(obs, modelLocker_->device);
    auto jitHid = utils::tensorDictToTorchDict(hidden_, modelLocker_->device);
    input.push_back(jitObs);
    input.push_back(jitHid);

    int id = -1;
    auto model = modelLocker_->getModel(&id);
    // Every getModel() must be paired with releaseModel(id). The forward pass
    // and the tuple unpacking can throw, so release on the exceptional path
    // as well before propagating; otherwise the model would stay checked out
    // from modelLocker_ (presumably blocking later model swaps — confirm
    // against ModelLocker).
    auto output = [&] {
      try {
        return model.get_method("act")(input).toTuple()->elements();
      } catch (...) {
        modelLocker_->releaseModel(id);
        throw;
      }
    }();
    modelLocker_->releaseModel(id);

    auto action = utils::iValueToTensorDict(output[0], torch::kCPU, true);
    hidden_ = utils::iValueToTensorDict(output[1], torch::kCPU, true);

    if (replayBuffer_ != nullptr) {
      multiStepBuffer_.pushObsAndAction(obs, action);
    }

    numAct_ += batchsize_;
    return action;
  }