in activemri/baselines/replay_buffer.py [0:0]
def load(self, path: str, capacity: Optional[int] = None) -> int:
    """Loads the replay buffer from the specified path.

    Args:
        path(str): The path from where the memory will be loaded from.
        capacity(int): If provided, the buffer is re-created with this much
            capacity and the stored transitions are copied into its first
            slots; the remaining slots are zero-filled. This value must be at
            least the length of the stored tensors. If ``None``, the stored
            tensors are used directly.

    Returns:
        int: The number of rows in the stored tensors (the saved length).

    Raises:
        ValueError: If ``capacity`` is smaller than the saved length.

    NOTE(review): ``position`` is restored exactly as saved. If the saved buffer
    had already wrapped around, growing ``capacity`` does not reorder entries,
    so confirm how callers treat the zero-filled slots ``[old_len:capacity)``.
    """
    data = torch.load(path)
    self.position = data["position"]
    self.mean_obs = data["mean_obs"]
    self.std_obs = data["std_obs"]
    self._m2_obs = data["m2_obs"]
    self.count_seen = data["count_seen"]
    old_len = data["observations"].shape[0]

    # Transition tensors; each attribute name matches its key in the saved dict.
    tensor_names = ("observations", "actions", "next_observations", "rewards", "dones")

    if capacity is None:
        for name in tensor_names:
            setattr(self, name, data[name])
        return old_len

    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if capacity < old_len:
        raise ValueError(
            f"capacity ({capacity}) must be at least the stored length ({old_len})"
        )

    for name in tensor_names:
        stored = data[name]
        # Preallocate from the stored tensor's trailing shape and dtype so that
        # non-1-D actions/rewards/dones copy correctly and dtypes agree with the
        # ``capacity is None`` branch above.
        buffer = torch.zeros(capacity, *stored.shape[1:], dtype=stored.dtype)
        buffer[:old_len] = stored
        setattr(self, name, buffer)
    return old_len