in activemri/baselines/replay_buffer.py [0:0]
def save(self, directory: str, name: str) -> str:
    """Save the buffer's tensors and normalization info to `directory/name`.

    The payload is first written to a temporary file created inside
    `directory` and only moved onto the final path once `torch.save` has
    finished, so a failed or interrupted save does not leave a
    half-written file at `directory/name`.

    Args:
        directory: Existing directory that will hold both the temporary
            file and the final file.
        name: File name of the final file, joined onto `directory`.

    Returns:
        The full path of the written file (`os.path.join(directory, name)`).

    Raises:
        Any exception raised while creating, writing or moving the file is
        propagated unchanged; the temporary file is removed first.
    """
    data = {
        "observations": self.observations,
        "actions": self.actions,
        "next_observations": self.next_observations,
        "rewards": self.rewards,
        "dones": self.dones,
        "position": self.position,
        "mean_obs": self.mean_obs,
        "std_obs": self.std_obs,
        "m2_obs": self._m2_obs,
        "count_seen": self.count_seen,
    }
    full_path = os.path.join(directory, name)
    # delete=False: the file must outlive close() so it can be moved into place.
    tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=directory)
    try:
        # Close (and thereby flush) the handle before moving the file.
        with tmp_file:
            torch.save(data, tmp_file)
        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename raises on Windows if `full_path` already exists.
        os.replace(tmp_file.name, full_path)
    except BaseException:
        # Covers failures of both the write and the final move, so no stray
        # temporary file is left behind. Cleanup is best-effort: it must not
        # mask the exception that brought us here.
        try:
            os.remove(tmp_file.name)
        except OSError:
            pass
        raise
    return full_path