FFTransition FFTransition::makeBatch()

in rela/types.cc [8:46]


// Collates `transitions` into a single batched FFTransition.
//
// The dict-valued fields (obs, action, nextObs) are gathered key-by-key with
// utils::tensorVecDictAppend and merged with utils::tensorDictJoin(..., 0);
// the plain tensor fields (reward, terminal, bootstrap) are combined with
// torch::stack along a new leading dimension 0.
//
// `device` is compared against the literal "cpu": for any other value every
// tensor of the batch is transferred to torch::Device(device) before the
// batch is returned; for "cpu" the joined tensors are returned as they are.
//
// NOTE(review): torch::stack rejects an empty tensor list, so an empty
// `transitions` vector is presumably not a supported input - confirm with
// callers.
FFTransition FFTransition::makeBatch(const std::vector<FFTransition>& transitions,
                                     const std::string& device) {
  // Per-field accumulators, filled in the order the transitions are given.
  TensorVecDict obsParts;
  TensorVecDict actionParts;
  TensorVecDict nextObsParts;
  std::vector<torch::Tensor> rewards;
  std::vector<torch::Tensor> terminals;
  std::vector<torch::Tensor> bootstraps;

  for (const auto& item : transitions) {
    utils::tensorVecDictAppend(obsParts, item.obs);
    utils::tensorVecDictAppend(actionParts, item.action);
    rewards.push_back(item.reward);
    terminals.push_back(item.terminal);
    bootstraps.push_back(item.bootstrap);
    utils::tensorVecDictAppend(nextObsParts, item.nextObs);
  }

  // Merge the collected pieces field by field along dimension 0.
  FFTransition batch;
  batch.obs = utils::tensorDictJoin(obsParts, 0);
  batch.action = utils::tensorDictJoin(actionParts, 0);
  batch.reward = torch::stack(rewards, 0);
  batch.terminal = torch::stack(terminals, 0);
  batch.bootstrap = torch::stack(bootstraps, 0);
  batch.nextObs = utils::tensorDictJoin(nextObsParts, 0);

  // Nothing to transfer when the caller asked for the CPU.
  if (device == "cpu") {
    return batch;
  }

  // Move every tensor of the batch onto the requested device.
  auto target = torch::Device(device);
  auto moveToTarget = [&](const torch::Tensor& tensor) {
    return tensor.to(target);
  };
  batch.obs = utils::tensorDictApply(batch.obs, moveToTarget);
  batch.action = utils::tensorDictApply(batch.action, moveToTarget);
  batch.reward = batch.reward.to(target);
  batch.terminal = batch.terminal.to(target);
  batch.bootstrap = batch.bootstrap.to(target);
  batch.nextObs = utils::tensorDictApply(batch.nextObs, moveToTarget);

  return batch;
}