in mtrl/replay_buffer.py [0:0]
def load(self, save_dir):
    """Restore the buffer contents from the chunk files in ``save_dir``.

    Every entry of ``save_dir`` is treated as a chunk file. Chunks are
    visited in ascending order of the integer that precedes the first
    ``"_"`` in their file name and their rows are copied back-to-back into
    ``env_obses``, ``next_env_obses``, ``actions``, ``rewards``,
    ``not_dones`` and ``task_obs`` (payload entries 0 through 5, in that
    order). Rows that would not fit within ``self.capacity`` are dropped,
    and loading stops once the buffer is full.

    After each loaded chunk ``self.idx`` is set to the index of the last
    row written; on return ``self.last_save`` equals ``self.idx``.

    Args:
        save_dir: Directory containing the saved chunk files.

    Raises:
        ValueError: If a file name in ``save_dir`` does not start with an
            integer followed by ``"_"`` (or has no such prefix at all).
    """
    # Sort numerically, not lexically, so that e.g. "10_..." follows "2_...".
    chunks = sorted(
        os.listdir(save_dir), key=lambda name: int(name.split("_")[0])
    )
    start = 0
    for chunk in chunks:
        path = os.path.join(save_dir, chunk)
        try:
            # SECURITY NOTE: torch.load unpickles the file, so ``save_dir``
            # must only contain trusted checkpoints.
            payload = torch.load(path)
        except EOFError as e:
            # A truncated/empty chunk is skipped on a best-effort basis.
            print(
                f"Skipping loading replay buffer from path: {path} due to error: {e}"
            )
            continue

        # Clamp the number of rows to the space left in the buffer. The
        # overflow case was originally added for resuming some very old
        # experiments and should not be needed with new experiments.
        select_till_index = min(payload[0].shape[0], self.capacity - start)
        end = start + select_till_index

        buffers = (
            self.env_obses,
            self.next_env_obses,
            self.actions,
            self.rewards,
            self.not_dones,
            self.task_obs,
        )
        for position, buffer in enumerate(buffers):
            buffer[start:end] = payload[position][:select_till_index]

        # NOTE(review): idx points at the last row written, not at the next
        # free slot -- confirm this matches how the writer side uses idx.
        self.idx = end - 1
        start = end
        print(f"Loaded replay buffer from path: {path}")

        if start >= self.capacity:
            # Buffer is full: any further chunk would contribute zero rows,
            # so skip the (expensive) deserialisation entirely.
            break
    self.last_save = self.idx