def get_multiple_slots()

in rlstructures/deprecated/batchers/buffers.py [0:0]
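Given a batch of trajectories, each described as a list of slot indices into
the internal buffers, this method pads the shorter trajectories, gathers the
corresponding slot contents, and concatenates them along the time dimension.
Per-trajectory lengths are accumulated from position_in_slot, and with
erase=True the consumed slots are released back to the free pool.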


    def get_multiple_slots(self, trajectories, erase=True):
        """
        Return the concatenation of multiple slots. This function is not
        well optimized and could be made faster.
        """
        assert isinstance(trajectories, (list, tuple))
        assert isinstance(trajectories[0], list)
        assert isinstance(trajectories[0][0], int)
        # 1: Pad all trajectories with None so they share the same length
        # (without mutating the caller's lists)
        max_l = max(len(traj) for traj in trajectories)
        ntrajectories = [
            traj + [None] * (max_l - len(traj)) for traj in trajectories
        ]

        # 2: Gather the slot contents, one slot position at a time
        length = torch.zeros(
            len(ntrajectories), dtype=torch.long, device=self._device
        )
        tensors = []
        for t in range(max_l):
            # Missing (padded) slots are mapped to index 0; their content is
            # ignored because `length` is never incremented for them below.
            nidxs = torch.tensor(
                [0 if traj[t] is None else traj[t] for traj in ntrajectories],
                device=self._device,
            )
            v = {k: self.buffers[k][nidxs] for k in self.buffers}
            pis = self.position_in_slot[nidxs]
            # Every slot followed by another slot must be completely full
            if t < max_l - 1:
                for i in range(len(pis)):
                    if ntrajectories[i][t + 1] is not None:
                        assert pis[i] == self.s_slots

            # Accumulate the number of valid timesteps per trajectory
            for i in range(len(pis)):
                if ntrajectories[i][t] is not None:
                    length[i] = length[i] + pis[i]

            tensors.append(v)
        ftrajectories = {
            k: torch.cat([t[k] for t in tensors], dim=1) for k in self.buffers
        }
        if erase:
            for traj in trajectories:
                for slot in traj:
                    if slot is not None:
                        self.set_free_slots(slot)

        return TemporalDictTensor(ftrajectories, length).shorten()
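
For illustration, here is a minimal, self-contained sketch of the
gather-and-concatenate pattern the method implements, using a plain tensor in
place of the class's buffers. The buffer layout (n_slots x s_slots x dim) and
every name in the snippet (buffer, s_slots, trajectories, ...) are assumptions
made for the example, not part of the rlstructures API.

import torch

n_slots, s_slots, dim = 8, 4, 2
# One row per slot, each slot holding s_slots timesteps of dim features.
buffer = torch.arange(n_slots * s_slots * dim, dtype=torch.float32).reshape(
    n_slots, s_slots, dim
)

# Two trajectories expressed as lists of slot indices; the second is shorter.
trajectories = [[3, 5], [6]]
max_l = max(len(traj) for traj in trajectories)
padded = [traj + [None] * (max_l - len(traj)) for traj in trajectories]

pieces = []
lengths = torch.zeros(len(padded), dtype=torch.long)
for t in range(max_l):
    # Missing slots are mapped to index 0; their content is ignored because
    # `lengths` is not incremented for them.
    idxs = torch.tensor([0 if traj[t] is None else traj[t] for traj in padded])
    pieces.append(buffer[idxs])  # shape: (n_traj, s_slots, dim)
    for i, traj in enumerate(padded):
        if traj[t] is not None:
            lengths[i] += s_slots  # in this toy example every used slot is full

result = torch.cat(pieces, dim=1)  # shape: (n_traj, max_l * s_slots, dim)
print(result.shape, lengths)  # torch.Size([2, 8, 2]) tensor([8, 4])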