def load()

in activemri/baselines/replay_buffer.py


    def load(self, path: str, capacity: Optional[int] = None):
        """Loads the replay buffer from the specified path.

        Args:
            path(str): The path from where the memory will be loaded from.
            capacity(int): If provided, the buffer is created with this much capacity. This
                    value must be larger than the length of the stored tensors.
        """
        data = torch.load(path)
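        # Restore the write position and the running observation statistics.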
        self.position = data["position"]
        self.mean_obs = data["mean_obs"]
        self.std_obs = data["std_obs"]
        self._m2_obs = data["m2_obs"]
        self.count_seen = data["count_seen"]

        old_len = data["observations"].shape[0]
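        # With no target capacity, reuse the stored tensors directly.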
        if capacity is None:
            self.observations = data["observations"]
            self.actions = data["actions"]
            self.next_observations = data["next_observations"]
            self.rewards = data["rewards"]
            self.dones = data["dones"]
        else:
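            # Allocate zero-initialized buffers of the requested capacity and
            # copy the saved transitions into their front.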
            assert capacity >= len(data["observations"])
            obs_shape = data["observations"].shape[1:]
            self.observations = torch.zeros(capacity, *obs_shape, dtype=torch.float32)
            self.actions = torch.zeros(capacity, dtype=torch.long)
            self.next_observations = torch.zeros(
                capacity, *obs_shape, dtype=torch.float32
            )
            self.rewards = torch.zeros(capacity, dtype=torch.float32)
            self.dones = torch.zeros(capacity, dtype=torch.bool)
            self.observations[:old_len] = data["observations"]
            self.actions[:old_len] = data["actions"]
            self.next_observations[:old_len] = data["next_observations"]
            self.rewards[:old_len] = data["rewards"]
            self.dones[:old_len] = data["dones"]

        return old_len
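
For reference, below is a minimal sketch of the checkpoint format that load() expects. Only the dictionary keys come from the code above; the observation shape, the stored length, the statistics shapes, the file name, and the commented-out buffer instance are illustrative assumptions rather than confirmed API.

    import torch

    obs_shape = (2, 4, 4)  # assumed observation shape, for illustration only
    stored_len = 8         # first dimension of the saved tensors

    checkpoint = {
        "position": 3,
        "mean_obs": torch.zeros(obs_shape),
        "std_obs": torch.ones(obs_shape),
        "m2_obs": torch.zeros(obs_shape),
        "count_seen": 3,
        "observations": torch.randn(stored_len, *obs_shape),
        "actions": torch.zeros(stored_len, dtype=torch.long),
        "next_observations": torch.randn(stored_len, *obs_shape),
        "rewards": torch.zeros(stored_len, dtype=torch.float32),
        "dones": torch.zeros(stored_len, dtype=torch.bool),
    }
    torch.save(checkpoint, "buffer.pt")

    # buffer = ...  # a replay buffer instance from this module (constructor not shown here)
    # old_len = buffer.load("buffer.pt", capacity=32)  # grows the buffer to capacity 32
    # assert old_len == stored_len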