hucc/replaybuffer.py [223:276]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        idx = sample_indices()
        while stack_obs and stack_obs > 1 and 'terminal' in self._b:
            raise NotImplementedError('Let\'s not use this feature...')
            # Select indices which won't include terminal states in [0,...,stack-1]
            found_term = False
            for i in range(stack_obs):
                if (
                    self._b['terminal']
                    .index_select(0, (idx + i * ilv) % self.max)
                    .any()
                ):
                    log.debug('Found terminal state, resampling')
                    found_term = True
                    break
            if not found_term:
                break
            idx = sample_indices()

        batch: Dict[str, th.Tensor] = dict()
        batch['_idx'] = idx

        if stack_obs:
            if 'obs' in self._b:
                batch['obs'] = th.stack(
                    [
                        self._b['obs'].index_select(
                            0, (idx + i * ilv) % self.max
                        )
                        for i in range(stack_obs)
                    ],
                    dim=1,
                )
            if 'next_obs' in self._b:
                batch['next_obs'] = th.stack(
                    [
                        self._b['next_obs'].index_select(
                            0, (idx + (i + 1) * ilv) % self.max
                        )
                        for i in range(stack_obs)
                    ],
                    dim=1,
                )
            for k in set(self._b.keys()) - set(['obs', 'next_obs']):
                batch[k] = self._b[k].index_select(
                    0, (idx + stack_obs - 1) % self.max
                )
        else:
            idx_mod = idx % self.max
            for k in set(self._b.keys()):
                batch[k] = self._b[k].index_select(0, idx_mod)

        for k, v in batch.items():
            batch[k] = v.to(device)
        return batch
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



hucc/replaybuffer.py [305:358]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        idx = sample_indices()
        while stack_obs and stack_obs > 1 and 'terminal' in self._b:
            raise NotImplementedError('Let\'s not use this feature...')
            # Select indices which won't include terminal states in [0,...,stack-1]
            found_term = False
            for i in range(stack_obs):
                if (
                    self._b['terminal']
                    .index_select(0, (idx + i * ilv) % self.max)
                    .any()
                ):
                    log.debug('Found terminal state, resampling')
                    found_term = True
                    break
            if not found_term:
                break
            idx = sample_indices()

        batch: Dict[str, th.Tensor] = dict()
        batch['_idx'] = idx

        if stack_obs:
            if 'obs' in self._b:
                batch['obs'] = th.stack(
                    [
                        self._b['obs'].index_select(
                            0, (idx + i * ilv) % self.max
                        )
                        for i in range(stack_obs)
                    ],
                    dim=1,
                )
            if 'next_obs' in self._b:
                batch['next_obs'] = th.stack(
                    [
                        self._b['next_obs'].index_select(
                            0, (idx + (i + 1) * ilv) % self.max
                        )
                        for i in range(stack_obs)
                    ],
                    dim=1,
                )
            for k in set(self._b.keys()) - set(['obs', 'next_obs']):
                batch[k] = self._b[k].index_select(
                    0, (idx + stack_obs - 1) % self.max
                )
        else:
            idx_mod = idx % self.max
            for k in set(self._b.keys()):
                batch[k] = self._b[k].index_select(0, idx_mod)

        for k, v in batch.items():
            batch[k] = v.to(device)
        return batch
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



