in svg/replay_buffer.py [0:0]
def load_data(self, save_dir):
    """Restore buffer contents from chunk files previously written to ``save_dir``.

    Each chunk file name is expected to look like ``<start>_<end>.<ext>``,
    where ``start`` and ``end`` are global transition indices. Files whose
    name contains ``'stats'`` are ignored. Only transitions with global
    index in ``[self.global_idx - self.capacity, self.global_idx)`` (or
    ``[0, self.global_idx)`` when fewer than ``capacity`` were ever stored)
    are loaded. They are written contiguously starting at local slot 0.

    Reads ``self.global_idx`` and ``self.capacity``. Sets ``self.idx``,
    ``self.full``, ``self.last_save`` and ``self.done_idxs``. Fills
    ``self.obses``, ``self.next_obses``, ``self.actions``, ``self.rewards``,
    ``self.not_dones`` and ``self.not_dones_no_max``.

    Args:
        save_dir: Directory containing the chunk files.

    Raises:
        AssertionError: If consecutive chunks on disk do not tile the loaded
            range contiguously, or if a full buffer ends up not completely
            filled.
        ValueError: If a non-``stats`` file name cannot be parsed as
            ``<start>_<end>``.
    """
    def parse_chunk(chunk):
        # "<start>_<end>.<ext>" -> (start, end) as ints.
        start, end = [int(x) for x in chunk.split('.')[0].split('_')]
        return (start, end)

    self.idx = 0

    chunks = [fname for fname in os.listdir(save_dir) if 'stats' not in fname]
    # Lay chunks down in order of their global start index.
    chunks = sorted(chunks, key=lambda fname: parse_chunk(fname)[0])

    # Use >= so that exactly `capacity` stored transitions also counts as full.
    # With `>`, global_idx == capacity left self.idx == capacity (one past the
    # last slot) instead of wrapping it to 0 below.
    self.full = self.global_idx >= self.capacity
    # Global index that maps to local slot 0.
    global_beginning = self.global_idx - self.capacity if self.full else 0

    for chunk in chunks:
        global_start, global_end = parse_chunk(chunk)
        # Skip chunks starting at or after the current global position.
        if global_start >= self.global_idx:
            continue
        start = global_start - global_beginning
        end = global_end - global_beginning
        # Skip chunks lying entirely before the retained window.
        if end <= 0:
            continue

        path = os.path.join(save_dir, chunk)
        # NOTE(review): torch.load unpickles the file; only load trusted
        # directories. Newer torch versions default to weights_only=True,
        # which may reject non-tensor payloads -- confirm against the
        # torch version in use.
        payload = torch.load(path)
        if start < 0:
            # The chunk straddles the window start: drop its first -start entries.
            payload = [x[-start:] for x in payload]
            start = 0
        # Chunks must tile the window with no gaps or overlaps.
        assert self.idx == start

        self.obses[start:end] = payload[0]
        self.next_obses[start:end] = payload[1]
        self.actions[start:end] = payload[2]
        self.rewards[start:end] = payload[3]
        self.not_dones[start:end] = payload[4]
        self.not_dones_no_max[start:end] = payload[5]
        self.idx = end

    self.last_save = self.idx

    if self.full:
        # A full buffer must have every slot populated; then wrap idx to 0.
        assert self.idx == self.capacity
        self.idx = 0

    last_idx = self.capacity if self.full else self.idx
    # Positions whose not_dones entry differs from 1 (presumably episode ends).
    self.done_idxs = SortedSet(np.where(1.-self.not_dones[:last_idx])[0])