ppo_ewma/minibatch_optimize.py

import torch as th

from . import logger
from .tree_util import tree_map
from . import torch_util as tu


def _fmt_row(width, row, header=False):
    out = " | ".join(_fmt_item(x, width) for x in row)
    if header:
        out = out + "\n" + "-" * len(out)
    return out


def _fmt_item(x, l):
    if th.is_tensor(x):
        assert x.dim() == 0
        x = float(x)
    if isinstance(x, float):
        v = abs(x)
        if (v < 1e-4 or v > 1e4) and v > 0:
            rep = "%7.2e" % x
        else:
            rep = "%7.5f" % x
    else:
        rep = str(x)
    return " " * (l - len(rep)) + rep


class LossDictPrinter:
    """
    Helps with incrementally printing out stats row by row in a formatted table
    """

    def __init__(self):
        self.printed_header = False

    def print_row(self, d):
        if not self.printed_header:
            logger.log(_fmt_row(12, d.keys()))
            self.printed_header = True
        logger.log(_fmt_row(12, d.values()))


def minibatch_optimize(
    train_fn: "function (dict) -> dict called on each minibatch that returns training stats",
    tensordict: "Dict[str->th.Tensor]",
    *,
    nepoch: "(int) number of epochs over dataset",
    nminibatch: "(int) number of minibatch per epoch",
    comm: "(MPI.Comm) MPI communicator",
    verbose: "(bool) print detailed stats" = False,
    epoch_fn: "function () -> dict to be called each epoch" = None,
):
    ldp = LossDictPrinter()
    epoch_dicts = []
    for _ in range(nepoch):
        mb_dicts = [
            train_fn(**mb) for mb in minibatch_gen(tensordict, nminibatch=nminibatch)
        ]
        local_dict = {k: float(v) for (k, v) in dict_mean(mb_dicts).items()}
        if epoch_fn is not None:
            local_dict.update(dict_mean(epoch_fn()))
        global_dict = dict_mean(comm.allgather(local_dict))
        epoch_dicts.append(global_dict)
        if verbose:
            ldp.print_row(global_dict)
    return epoch_dicts


def dict_mean(ds):
    return {k: sum(d[k] for d in ds) / len(ds) for k in ds[0].keys()}


def to_th_device(x):
    assert th.is_tensor(x), "to_th_device should only be applied to torch tensors"
    dtype = th.float32 if x.dtype == th.float64 else None
    return x.to(tu.dev(), dtype=dtype)


def minibatch_gen(data, *, batch_size=None, nminibatch=None, forever=False):
    assert (batch_size is None) != (
        nminibatch is None
    ), "only one of batch_size or nminibatch should be specified"
    ntrain = tu.batch_len(data)
    if nminibatch is None:
        nminibatch = max(ntrain // batch_size, 1)
    if nminibatch > ntrain:
        assert (
            nminibatch % ntrain == 0
        ), "nminibatch must be a multiple of ntrain if it is larger"
        nchunks = nminibatch // ntrain
        nminibatch = ntrain
    else:
        nchunks = 1
    while True:
        for mbinds in th.chunk(th.randperm(ntrain), nminibatch):
            mb = tree_map(to_th_device, tu.tree_slice(data, mbinds))
            for chunknum in range(nchunks):
                # raises an IndexError if nminibatch is incompatible with num_envs and nstep
                yield tree_map(lambda x: th.chunk(x, nchunks, dim=1)[chunknum], mb)
        if not forever:
            return
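A minimal usage sketch of minibatch_optimize, assuming rollout tensors laid out as (num_envs, nstep, ...) and an mpi4py communicator. The dummy train_fn, the tensor shapes, and the example values are illustrative assumptions rather than code from this repository; only minibatch_optimize itself comes from this file.

# Hypothetical driver for minibatch_optimize; shapes, values, and the dummy
# train_fn below are assumptions for illustration, not part of this repository.
import torch as th
from mpi4py import MPI

from ppo_ewma.minibatch_optimize import minibatch_optimize

# Rollout tensors laid out as (num_envs, nstep, ...); minibatch_gen slices
# along the first dimension. Note that it also moves each minibatch to the
# device returned by torch_util.dev(), so any device setup that module
# requires must happen before this point.
num_envs, nstep = 8, 64
tensordict = {
    "ob": th.randn(num_envs, nstep, 4),
    "ac": th.randint(0, 2, (num_envs, nstep)),
    "adv": th.randn(num_envs, nstep),
}

def train_fn(**mb):
    # A real train_fn would compute the PPO loss on the minibatch and take an
    # optimizer step; here it just returns one scalar stat per key.
    return {"loss": mb["adv"].pow(2).mean()}

epoch_stats = minibatch_optimize(
    train_fn,
    tensordict,
    nepoch=3,
    nminibatch=4,
    comm=MPI.COMM_WORLD,  # with a single process, allgather simply returns [local_dict]
    verbose=True,
)
print(epoch_stats[-1])  # per-key averages from the last epoch

Note on the nminibatch > ntrain branch: when more minibatches are requested than there are rows in the batch, minibatch_gen keeps one row per minibatch and instead splits each row along dim=1 into nminibatch // ntrain chunks, which is why nminibatch must then be a multiple of ntrain.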