"""
Logging / checkpointing helper for PPO-EWMA training loops.

Tracks the cumulative environment-interaction count, periodically dumps scalar
diagnostics through the project `logger`, and saves model checkpoints on an
interaction-count schedule, MPI-aware (only rank 0 writes checkpoints).
"""
import os
import resource
import time

import numpy as np
import torch as th
from mpi4py import MPI

from . import logger


def rcm(start, stop, modulus, mode="[)"):
    """
    Interval contains multiple: return True iff the interval from `start` to
    `stop` contains a multiple of `modulus`, where `mode` specifies whether
    the interval is closed ('[', ']') or open ('(', ')') on either side.

    This was very tricky to get right.
    """
    left_hit = start % modulus == 0
    middle_hit = modulus * (start // modulus + 1) < stop
    # ^^^ modulus * (start // modulus + 1) is the smallest multiple of modulus that's
    # strictly greater than start
    right_hit = stop % modulus == 0
    return (start < stop) and (
        (left_hit and mode[0] == "[")
        or middle_hit
        or (right_hit and mode[1] == "]")
    )


class LogSaveHelper:
    """Periodic logging and model checkpointing driven by interaction counts."""

    def __init__(
        self,
        model: "(nn.Module)",
        ic_per_step: "(int) number of interactions per logging step",
        comm: "(MPI.Comm)" = None,
        ic_per_save: "(int) save only after this many interactions" = 100_000,
        save_mode: "(str) last: keep last model, all: keep all" = "none",
        t0: "(float) override training start timestamp" = None,
        log_callbacks: "(list) extra callbacks to run before self.log()" = None,
        log_new_eps: "(bool) whether to log statistics for new episodes from non-rolling buffer" = False,
    ):
        self.model = model
        self.comm = comm or MPI.COMM_WORLD
        self.ic_per_step = ic_per_step
        self.ic_per_save = ic_per_save
        self.save_mode = save_mode
        self.save_idx = 0
        self.last_ic = 0
        self.log_idx = 0
        self.start_time = self.last_time = time.time()
        self.total_interact_count = 0
        if ic_per_save > 0:
            # Checkpoint the freshly initialized model before training starts.
            self.save()
        # Re-stamp after the (possibly slow) initial save so throughput numbers
        # exclude checkpoint time; `t0` lets callers carry a start time across
        # restarts.
        self.start_time = self.last_time = t0 or time.time()
        self.log_callbacks = log_callbacks
        self.log_new_eps = log_new_eps
        self.roller_stats = {}

    def __call__(self):
        """Advance one logging step: log stats, and save if a save boundary was crossed."""
        self.total_interact_count += self.ic_per_step
        assert self.total_interact_count > 0, "Should start counting at 1"
        # Decide BEFORE self.log(), which advances self.last_ic. Save iff the
        # half-open interval (last_ic, total_interact_count] crossed a multiple
        # of ic_per_save.
        will_save = (self.ic_per_save > 0) and rcm(
            self.last_ic + 1, self.total_interact_count + 1, self.ic_per_save
        )
        self.log()
        if will_save:
            self.save()
        return True

    def gather_roller_stats(self, roller):
        """Collect episode-return/length statistics from `roller` across all ranks."""
        self.roller_stats = {
            "EpRewMean": self._nanmean([] if roller is None else roller.recent_eprets),
            "EpLenMean": self._nanmean([] if roller is None else roller.recent_eplens),
        }
        if roller is not None and self.log_new_eps:
            assert roller.has_non_rolling_eps, "roller needs keep_non_rolling"
            ret_n, ret_mean, ret_std = self._nanmoments(roller.non_rolling_eprets)
            _len_n, len_mean, len_std = self._nanmoments(roller.non_rolling_eplens)
            # Consume the buffer so each episode is only counted once.
            roller.clear_non_rolling_episode_buf()
            self.roller_stats.update(
                {
                    "NewEpNum": ret_n,
                    "NewEpRewMean": ret_mean,
                    "NewEpRewStd": ret_std,
                    "NewEpLenMean": len_mean,
                    "NewEpLenStd": len_std,
                }
            )

    def log(self):
        """Dump all key/value diagnostics for the current step via `logger`."""
        if self.log_callbacks is not None:
            for callback in self.log_callbacks:
                callback()
        for k, v in self.roller_stats.items():
            logger.logkv(k, v)
        logger.logkv("Misc/InteractCount", self.total_interact_count)
        cur_time = time.time()
        delta_time = cur_time - self.last_time
        delta_ic = self.total_interact_count - self.last_ic
        logger.logkv("Misc/TimeElapsed", cur_time - self.start_time)
        logger.logkv("IPS_total", delta_ic / delta_time)
        logger.logkv("del_time", delta_time)
        logger.logkv("Iter", self.log_idx)
        # NOTE(review): ru_maxrss is kilobytes on Linux (bytes on macOS);
        # * 1000 approximates bytes — kept as-is to preserve logged units.
        logger.logkv(
            "CpuMaxMemory", resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1000
        )
        if th.cuda.is_available():
            logger.logkv("GpuMaxMemory", th.cuda.max_memory_allocated())
            # reset_max_memory_allocated() is deprecated; reset_peak_memory_stats()
            # is its documented replacement and resets the same peak counter.
            th.cuda.reset_peak_memory_stats()
        if self.comm.rank == 0:
            # .get() avoids a KeyError when RCALL_LOGDIR is not set in the env.
            print("RCALL_LOGDIR: ", os.environ.get("RCALL_LOGDIR"))
        logger.dumpkvs()
        self.last_time = cur_time
        self.last_ic = self.total_interact_count
        self.log_idx += 1

    def save(self):
        """Serialize the model to the logger directory per `save_mode` (rank 0 of self.comm only)."""
        if self.comm.rank != 0:
            return
        if self.save_mode == "last":
            basename = "model"
        elif self.save_mode == "all":
            basename = f"model{self.save_idx:03d}"
        elif self.save_mode == "none":
            return
        else:
            raise NotImplementedError
        # self.comm may be a sub-communicator: several of its rank-0 processes
        # can coexist in COMM_WORLD, so disambiguate filenames by world rank.
        suffix = f"_rank{MPI.COMM_WORLD.rank:03d}" if MPI.COMM_WORLD.rank != 0 else ""
        basename += f"{suffix}.jd"
        fname = os.path.join(logger.get_dir(), basename)
        logger.log("Saving to ", fname, f"IC={self.total_interact_count}")
        th.save(self.model, fname, pickle_protocol=-1)
        self.save_idx += 1

    def _nanmean(self, xs):
        """Mean of `xs` pooled across all ranks; NaN if the pooled list is empty."""
        xs = _flatten(self.comm.allgather(xs))
        return np.nan if len(xs) == 0 else np.mean(xs)

    def _nanmoments(self, xs, **kwargs):
        """(count, mean, std) of `xs` pooled across all ranks; see _nanmoments_local."""
        xs = _flatten(self.comm.allgather(xs))
        return _nanmoments_local(xs, **kwargs)


def _flatten(ls):
    """Flatten one level: list of lists -> single list."""
    return [el for sublist in ls for el in sublist]


def _nanmoments_local(xs, ddof=1):
    """
    Return (n, mean, std) of `xs` with `ddof` delta degrees of freedom,
    substituting NaN for moments that are undefined at the given sample size.
    """
    n = len(xs)
    if n == 0:
        return n, np.nan, np.nan
    elif n == ddof:
        # std with ddof == n would divide by zero; report NaN instead.
        return n, np.mean(xs), np.nan
    else:
        return n, np.mean(xs), np.std(xs, ddof=ddof)