in rela/types.cc [140:182]
// Collates a list of RNNTransition samples into one batched RNNTransition.
//
// Field handling (input order is preserved):
//   - obs / h0 / action: per-key tensors are joined along dim 1 with
//     utils::tensorDictJoin (dim 1 is the batch dim for these dicts,
//     including the RNN hidden state h0).
//   - reward / terminal / bootstrap: stacked along a new dim 1.
//   - seqLen: stacked along dim 0, then dim 1 is squeezed away.
//     NOTE(review): this presumably expects each sample's seqLen to have
//     shape [1] — confirm against the producers of RNNTransition.
//
// @param transitions  samples to batch. NOTE(review): an empty vector
//                     presumably makes torch::stack throw.
// @param device       target device string; any value other than "cpu"
//                     moves every field of the result to torch::Device(device).
// @return the batched transition.
RNNTransition RNNTransition::makeBatch(const std::vector<RNNTransition>& transitions,
                                       const std::string& device) {
  // Per-field accumulators, filled in the same order as `transitions`.
  TensorVecDict obsParts;
  TensorVecDict hidParts;
  TensorVecDict actParts;
  std::vector<torch::Tensor> rewards;
  std::vector<torch::Tensor> terminals;
  std::vector<torch::Tensor> bootstraps;
  std::vector<torch::Tensor> seqLens;

  for (const auto& sample : transitions) {
    utils::tensorVecDictAppend(obsParts, sample.obs);
    utils::tensorVecDictAppend(hidParts, sample.h0);
    utils::tensorVecDictAppend(actParts, sample.action);
    rewards.push_back(sample.reward);
    terminals.push_back(sample.terminal);
    bootstraps.push_back(sample.bootstrap);
    seqLens.push_back(sample.seqLen);
  }

  RNNTransition out;
  // Dim 1 is the batch dim for every dict field, including the RNN hidden state.
  out.obs = utils::tensorDictJoin(obsParts, 1);
  out.h0 = utils::tensorDictJoin(hidParts, 1);
  out.action = utils::tensorDictJoin(actParts, 1);
  out.reward = torch::stack(rewards, 1);
  out.terminal = torch::stack(terminals, 1);
  out.bootstrap = torch::stack(bootstraps, 1);
  out.seqLen = torch::stack(seqLens, 0).squeeze(1);

  // Leave the result where join/stack produced it when the target is "cpu".
  if (device == "cpu") {
    return out;
  }

  torch::Device target(device);
  auto moveToTarget = [&target](const torch::Tensor& x) { return x.to(target); };
  out.obs = utils::tensorDictApply(out.obs, moveToTarget);
  out.h0 = utils::tensorDictApply(out.h0, moveToTarget);
  out.action = utils::tensorDictApply(out.action, moveToTarget);
  out.reward = out.reward.to(target);
  out.terminal = out.terminal.to(target);
  out.bootstrap = out.bootstrap.to(target);
  out.seqLen = out.seqLen.to(target);
  return out;
}