in rela/transition.cc [120:158]
// Collates a list of single-step transitions into one batched FFTransition.
//
// Every field of the result is built by stacking the corresponding field of
// each element of `transitions` along a new leading dimension 0, so index i of
// the batch comes from transitions[i]. Plain tensor fields (reward, terminal,
// bootstrap) use torch::stack; TensorDict fields (obs, action, nextObs) use
// tensor_dict::stack.
//
// @param transitions  transitions to batch; corresponding fields must be
//                     stack-compatible across elements.
//                     NOTE(review): an empty vector is not handled here —
//                     torch::stack rejects an empty list, and the behavior of
//                     tensor_dict::stack (called first) on an empty vector is
//                     not visible from this file; confirm callers never pass one.
// @param device       target device string. When it is anything other than
//                     exactly "cpu", every field of the stacked batch is moved
//                     with Tensor::to(torch::Device(device)); otherwise the
//                     batch is returned where torch::stack produced it.
// @return the batched transition.
FFTransition FFTransition::makeBatch(
    const std::vector<FFTransition>& transitions, const std::string& device) {
  const size_t batchSize = transitions.size();

  // The final size of every per-field list is known up front; reserving avoids
  // repeated reallocation while each field is copied out of `transitions`.
  std::vector<TensorDict> obsVec;
  std::vector<TensorDict> actionVec;
  std::vector<torch::Tensor> rewardVec;
  std::vector<torch::Tensor> terminalVec;
  std::vector<torch::Tensor> bootstrapVec;
  std::vector<TensorDict> nextObsVec;
  obsVec.reserve(batchSize);
  actionVec.reserve(batchSize);
  rewardVec.reserve(batchSize);
  terminalVec.reserve(batchSize);
  bootstrapVec.reserve(batchSize);
  nextObsVec.reserve(batchSize);

  // `transitions` is const, so each field is copied into its per-field list.
  for (const auto& transition : transitions) {
    obsVec.push_back(transition.obs);
    actionVec.push_back(transition.action);
    rewardVec.push_back(transition.reward);
    terminalVec.push_back(transition.terminal);
    bootstrapVec.push_back(transition.bootstrap);
    nextObsVec.push_back(transition.nextObs);
  }

  FFTransition batch;
  batch.obs = tensor_dict::stack(obsVec, 0);
  batch.action = tensor_dict::stack(actionVec, 0);
  batch.reward = torch::stack(rewardVec, 0);
  batch.terminal = torch::stack(terminalVec, 0);
  batch.bootstrap = torch::stack(bootstrapVec, 0);
  batch.nextObs = tensor_dict::stack(nextObsVec, 0);

  // Only the literal string "cpu" skips the transfer.
  if (device != "cpu") {
    auto d = torch::Device(device);
    auto toDevice = [&](const torch::Tensor& t) { return t.to(d); };
    batch.obs = tensor_dict::apply(batch.obs, toDevice);
    batch.action = tensor_dict::apply(batch.action, toDevice);
    batch.reward = batch.reward.to(d);
    batch.terminal = batch.terminal.to(d);
    batch.bootstrap = batch.bootstrap.to(d);
    batch.nextObs = tensor_dict::apply(batch.nextObs, toDevice);
  }
  return batch;
}