in ppo_ewma/log_save_helper.py [0:0]
def save(self):
    """Checkpoint ``self.model`` into the logger's output directory.

    Only rank 0 of ``self.comm`` writes; every other rank returns immediately.
    The file name depends on ``self.save_mode``:

    - ``"last"``: ``model.jd`` (overwritten on every call)
    - ``"all"``: ``model{save_idx:03d}.jd`` (a new file per call)
    - ``"none"``: nothing is written

    If the writing process is not rank 0 of ``MPI.COMM_WORLD`` (e.g. when
    ``self.comm`` is a sub-communicator), ``_rank{world_rank:03d}`` is
    appended before the extension so different groups get distinct files.

    The whole ``self.model`` object is pickled via ``th.save``. It is first
    written to a ``.tmp`` sibling and then moved into place with
    ``os.replace``. An interrupted save therefore never leaves a truncated
    file under the final name. This matters for ``"last"``, which would
    otherwise destroy the only checkpoint.

    ``self.save_idx`` is incremented after each successful write.

    Raises:
        NotImplementedError: if ``self.save_mode`` is not one of the modes
            above (only raised on the rank that saves).
    """
    if self.comm.rank != 0:
        return
    if self.save_mode == "last":
        basename = "model"
    elif self.save_mode == "all":
        basename = f"model{self.save_idx:03d}"
    elif self.save_mode == "none":
        return
    else:
        raise NotImplementedError(f"unsupported save_mode {self.save_mode!r}")
    world_rank = MPI.COMM_WORLD.rank
    suffix = f"_rank{world_rank:03d}" if world_rank != 0 else ""
    basename += f"{suffix}.jd"
    fname = os.path.join(logger.get_dir(), basename)
    logger.log("Saving to ", fname, f"IC={self.total_interact_count}")
    # Write to a temporary sibling in the same directory, then atomically swap
    # it in, so a crash/preemption mid-write keeps the previous checkpoint.
    tmp_fname = fname + ".tmp"
    try:
        # pickle_protocol=-1 selects pickle's highest available protocol.
        th.save(self.model, tmp_fname, pickle_protocol=-1)
        os.replace(tmp_fname, fname)
    except BaseException:
        # Best-effort cleanup of the partial temp file; re-raise the original.
        if os.path.exists(tmp_fname):
            os.remove(tmp_fname)
        raise
    self.save_idx += 1