def put()

in salina/rl/replay_buffer.py [0:0]


    def put(self, workspace, time_size=None, padding=None):
        assert (
            workspace._all_variables_same_time_size()
        ), "Only works with workspace where all variables have the same time_size"
        T = workspace.time_size()
        if not time_size is None:
            assert time_size <= T
            n = T - time_size + 1
            if padding is None:
                padding = 1
            for t in range(0, n, padding):
                nworkspace = workspace.subtime(t, t + time_size)
                self.put(nworkspace)
            return

        all_tensors = {
            k: workspace.get_full(k).detach().to(self.device) for k in workspace.keys()
        }
        if self.variables is None:
            self.variables = {}
            for k, v in all_tensors.items():
                s = list(v.size())
                s[1] = self.max_size
                _s=copy.deepcopy(s)
                s[0]=_s[1]
                s[1]=_s[0]

                tensor = torch.zeros(*s, dtype=v.dtype, device=self.device)
                print(
                    "[ReplayBuffer] Var ",
                    k,
                    " size=",
                    s,
                    " dtype=",
                    v.dtype,
                    " device=",
                    self.device,
                )
                self.variables[k] = tensor
            self.is_full = False
            self.position = 0

        B = None
        arange = None
        indexes = None
        for k, v in all_tensors.items():
            if B is None:
                B = v.size()[1]
            B = min(self.position + B, self.max_size)
            B = B - self.position
            if indexes is None:
                indexes = torch.arange(B) + self.position
                arange = torch.arange(B)
            indexes = indexes.to(v.device)
            arange = arange.to(v.device)
            self.variables[k][indexes] = v.detach().transpose(0,1)

        self.position = self.position + B
        if self.position >= self.max_size:
            self.position = 0
            self.is_full = True