in hucc/replaybuffer.py [0:0]
def put_row(self, data: Dict[str, th.Tensor]) -> None:
    """Append one row of ``self.interleave`` entries to the ring buffer.

    Every tensor in ``data`` must have a leading dimension equal to
    ``self.interleave``. On the first call, the buffers are created lazily
    from per-key specs derived from ``data`` (taking ``self.device`` from the
    first tensor if it is unset). When the storage is full, the oldest
    entries are overwritten and ``self.start`` moves past them.

    Args:
        data: Mapping from field name to a batch of values. ``obs`` and
            ``next_obs`` must be present exactly when the buffer stores
            them; ``action`` and every other key must have a matching
            buffer. Keys other than ``obs``/``next_obs``/``action`` are
            ``squeeze()``-d before being stored.

    Raises:
        AssertionError: if ``data`` is empty on first use, a leading
            dimension differs from ``self.interleave``, or the keys of
            ``data`` and the existing buffers disagree.
        RuntimeError: if lazy initialization did not create the buffers.
    """
    if self._b is None:  # Lazy auto-initialization
        assert len(data) > 0
        if self.device is None:
            self.device = next(iter(data.values())).device
        self._init_buffers(
            self.specsFromTensors({k: v[0] for k, v in data.items()})
        )
    if self._b is None:
        raise RuntimeError()  # to make the linter happy

    bsz = self.interleave
    for k, v in data.items():
        assert v.shape[0] == bsz, (
            f'Expected leading dim {bsz} for "{k}", got {v.shape[0]}'
        )

    # Target slots: the bsz positions right after the newest entry,
    # wrapping around the end of the storage.
    end = (self.start + self.size) % self.max
    idx = th.arange(end, end + bsz) % self.max

    # Observations must be supplied exactly when buffers exist for them.
    for k in ('obs', 'next_obs'):
        if k in data:
            assert k in self._b, f'No buffer for "{k}"'
            self._b[k][idx] = data[k].detach().to(self._b[k])
        else:
            assert k not in self._b, f'Missing "{k}" in data'
    if 'action' in data:
        assert 'action' in self._b, 'No buffer for "action"'
        self._b['action'][idx] = (
            data['action'].detach().to(self._b['action'])
        )
    for k in set(data.keys()) - {'obs', 'next_obs', 'action'}:
        assert k in self._b, f'No buffer for "{k}"'
        # squeeze() drops singleton dims (e.g. a (bsz, 1) column becomes
        # (bsz,)); presumably these buffers hold per-entry scalars -- confirm.
        self._b[k][idx] = data[k].squeeze().detach().to(self._b[k])

    # Advance the ring. On overflow, drop exactly the entries that were
    # overwritten; this equals bsz only when the buffer was already full.
    overflow = self.size + bsz - self.max
    if overflow > 0:
        self.start = (self.start + overflow) % self.max
        self.size = self.max
    else:
        self.size += bsz
    # Number of rows written so far, capped at max_tlen.
    self.tlen = min(self.tlen + 1, self.max_tlen)