in rlstructures/deprecated/batchers/buffers.py [0:0]
def get_multiple_slots(self, trajectories, erase=True):
    """
    Gather several chains of slots and concatenate each chain along dim 1.

    Each element of ``trajectories`` is a list of slot ids that are read in
    order and concatenated (``torch.cat(..., dim=1)``) into one output row.
    Shorter lists are padded to the longest one. For the missing steps slot 0
    is read as a placeholder, and it does not count toward the row's length.
    The caller's lists are NOT modified.

    Args:
        trajectories (list or tuple of list of int): one list of slot ids per
            output row. The first list must be non-empty.
        erase (bool): if True, every slot id in ``trajectories`` is released
            with ``self.set_free_slots`` after its content has been gathered.

    Returns:
        TemporalDictTensor: built from the concatenated buffers and the
        per-row lengths (sum of ``self.position_in_slot`` over the row's real
        slots), then ``shorten()``-ed.

    Raises:
        AssertionError: if the input is malformed, or if a slot that is
            followed by another slot in the same row does not have
            ``position_in_slot == self.s_slots``.
    """
    assert isinstance(trajectories, (list, tuple))
    assert len(trajectories) > 0, "trajectories must not be empty"
    assert isinstance(trajectories[0], list)
    assert len(trajectories[0]) > 0, "the first trajectory must not be empty"
    assert isinstance(trajectories[0][0], int)

    # 1: Unify the size of all trajectories by padding with None.
    # Padded *copies* are built so that the caller's lists are left untouched.
    max_l = max(len(traj) for traj in trajectories)
    padded = [list(traj) + [None] * (max_l - len(traj)) for traj in trajectories]

    # 2: Copy the content, one time-step (one slot per row) at a time.
    lengths = [0] * len(padded)
    tensors = []
    for step in range(max_l):
        slot_ids = [traj[step] for traj in padded]
        # Padding reads slot 0 as a filler; it is excluded from `lengths`.
        idxs = torch.tensor(
            [0 if sid is None else sid for sid in slot_ids], device=self._device
        )
        tensors.append({name: self.buffers[name][idxs] for name in self.buffers})
        # Transfer the positions to the host once per step instead of indexing
        # the device tensor element by element.
        # NOTE(review): assumes position_in_slot holds integer counts (the
        # result is stored as a long tensor, as before).
        positions = self.position_in_slot[idxs].tolist()
        for row, pos in enumerate(positions):
            # A slot followed by another slot in the same row must be full,
            # otherwise the concatenated row would contain a gap.
            if step < max_l - 1 and padded[row][step + 1] is not None:
                assert pos == self.s_slots, (
                    f"slot {padded[row][step]} (row {row}, step {step}) is "
                    f"followed by another slot but has position {pos}, "
                    f"expected {self.s_slots}"
                )
            if padded[row][step] is not None:
                lengths[row] += pos

    length = torch.tensor(lengths, dtype=torch.long, device=self._device)
    ftrajectories = {
        name: torch.cat([t[name] for t in tensors], dim=1) for name in self.buffers
    }
    if erase:
        # Advanced indexing above produced copies, so releasing the slots now
        # does not affect the gathered data.
        for traj in trajectories:
            for slot_id in traj:
                if slot_id is not None:
                    self.set_free_slots(slot_id)
    return TemporalDictTensor(ftrajectories, length).shorten()