# ppo_ewma/logger.py
import os
import sys
import shutil
import os.path as osp
import json
import time
import datetime
import tempfile
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from functools import partial, wraps
from mpi4py import MPI
def mpi_weighted_mean(comm, local_name2valcount):
"""
Perform a weighted average over dicts that are each on a different node
Input: local_name2valcount: dict mapping key -> (value, count)
Returns: key -> mean
"""
local_name2valcount = {
name: (float(val), count)
for (name, (val, count)) in local_name2valcount.items()
}
all_name2valcount = comm.gather(local_name2valcount)
if comm.rank == 0:
name2sum = defaultdict(float)
name2count = defaultdict(float)
for n2vc in all_name2valcount:
for (name, (val, count)) in n2vc.items():
name2sum[name] += val * count
name2count[name] += count
return {name: name2sum[name] / name2count[name] for name in name2sum}
else:
return {}
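# Worked example (a sketch with made-up numbers, not executed): if rank 0
# holds {"loss": (2.0, 1)} and rank 1 holds {"loss": (4.0, 3)}, then rank 0
# gathers both dicts and computes (2.0 * 1 + 4.0 * 3) / (1 + 3) == 3.5,
# returning {"loss": 3.5}; every other rank returns {}.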
class KVWriter(ABC):
@abstractmethod
def writekvs(self, kvs):
pass
@abstractmethod
def close(self):
pass
class SeqWriter(ABC):
@abstractmethod
def writeseq(self, seq):
pass
@abstractmethod
def close(self):
pass
class HumanOutputFormat(KVWriter, SeqWriter):
def __init__(self, filename_or_file):
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "wt")
self.own_file = True
else:
            assert hasattr(filename_or_file, "write"), (
                "expected file or str, got %s" % filename_or_file
            )
self.file = filename_or_file
self.own_file = False
def writekvs(self, kvs):
# Create strings for printing
        key2str = []  # list of (key, valstr) display pairs
for (key, val) in sorted(kvs.items()):
if hasattr(val, "__float__"):
valstr = "%-8.3g" % val
else:
valstr = str(val)
key2str.append((self._truncate(key), self._truncate(valstr)))
        # Find max widths
        if len(key2str) == 0:
            print("WARNING: tried to write empty key-value dict")
            return
        keywidth = max(len(k) for (k, _) in key2str)
        valwidth = max(len(v) for (_, v) in key2str)
# Write out the data
dashes = "-" * (keywidth + valwidth + 7)
lines = [dashes]
for (key, val) in sorted(key2str, key=lambda kv: kv[0].lower()):
lines.append(
"| %s%s | %s%s |"
% (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
)
lines.append(dashes)
self.file.write("\n".join(lines) + "\n")
# Flush the output to the file
self.file.flush()
def _truncate(self, s):
maxlen = 30
return s[: maxlen - 3] + "..." if len(s) > maxlen else s
def writeseq(self, seq):
seq = list(seq)
for (i, elem) in enumerate(seq):
self.file.write(elem)
if i < len(seq) - 1: # add space unless this is the last one
self.file.write(" ")
self.file.write("\n")
self.file.flush()
def close(self):
if self.own_file:
self.file.close()
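# For reference, writekvs above renders a box like the following (an
# illustrative sketch for {"eplenmean": 100.0, "eprewmean": 2.5}; exact
# widths depend on the longest key and value):
# ------------------------
# | eplenmean | 100      |
# | eprewmean | 2.5      |
# ------------------------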
class JSONOutputFormat(KVWriter):
def __init__(self, filename):
self.file = open(filename, "wt")
    def writekvs(self, kvs):
        # Cast array/tensor scalars to float so they are JSON-serializable,
        # without mutating the caller's dict
        kvs = {k: float(v) if hasattr(v, "dtype") else v for k, v in kvs.items()}
        self.file.write(json.dumps(kvs) + "\n")
        self.file.flush()
def close(self):
self.file.close()
class CSVOutputFormat(KVWriter):
def __init__(self, filename):
self.file = open(filename, "w+t")
self.keys = []
self.sep = ","
def writekvs(self, kvs):
# Add our current row to the history
extra_keys = list(kvs.keys() - self.keys)
extra_keys.sort()
if extra_keys:
self.keys.extend(extra_keys)
self.file.seek(0)
lines = self.file.readlines()
self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(self.sep)
                self.file.write(k)
self.file.write("\n")
for line in lines[1:]:
self.file.write(line[:-1])
self.file.write(self.sep * len(extra_keys))
self.file.write("\n")
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(self.sep)
v = kvs.get(k)
if hasattr(v, "__float__"):
v = float(v)
if v is not None:
self.file.write(str(v))
self.file.write("\n")
self.file.flush()
def close(self):
self.file.close()
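# Illustrative sketch of the key-extension behavior above (hypothetical
# values): after writekvs({"a": 1}) the file holds "a\n1\n". A later
# writekvs({"a": 2, "b": 3}) rewrites the header and pads the old row with
# empty cells, leaving:
#     a,b
#     1,
#     2,3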
class TensorBoardOutputFormat(KVWriter):
"""
Dumps key/value pairs into TensorBoard's numeric format.
"""
def __init__(self, dir):
os.makedirs(dir, exist_ok=True)
self.dir = dir
self.step = 1
prefix = "events"
path = osp.join(osp.abspath(dir), prefix)
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.core.util import event_pb2
from tensorflow.python.util import compat
self.tf = tf
self.event_pb2 = event_pb2
self.pywrap_tensorflow = pywrap_tensorflow
self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
def writekvs(self, kvs):
def summary_val(k, v):
kwargs = {"tag": k, "simple_value": float(v)}
return self.tf.Summary.Value(**kwargs)
summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
        event.step = self.step  # is there any reason why you'd want to specify the step?
self.writer.WriteEvent(event)
self.writer.Flush()
self.step += 1
def close(self):
if self.writer:
self.writer.Close()
self.writer = None
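# Note: the events files written here can be loaded back into a pandas
# DataFrame with read_tb() at the bottom of this module.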
def make_output_format(format, ev_dir, log_suffix=""):
os.makedirs(ev_dir, exist_ok=True)
if format == "stdout":
return HumanOutputFormat(sys.stdout)
elif format == "log":
return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
elif format == "json":
return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
elif format == "csv":
return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
elif format == "tensorboard":
return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
else:
raise ValueError("Unknown format specified: %s" % (format,))
# ================================================================
# API
# ================================================================
def logkv(key, val):
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
    If called many times, the last value will be used.
"""
get_current().logkv(key, val)
def logkv_mean(key, val):
"""
    The same as logkv(), but if called many times, the values will be averaged.
"""
get_current().logkv_mean(key, val)
def logkvs(d):
"""
Log a dictionary of key-value pairs
"""
for (k, v) in d.items():
logkv(k, v)
def logkvs_mean(d):
"""
Log a dictionary of key-value pairs with averaging over multiple calls
"""
for (k, v) in d.items():
logkv_mean(k, v)
def dumpkvs():
"""
Write all of the diagnostics from the current iteration
"""
return get_current().dumpkvs()
def getkvs():
return get_current().name2val
def log(*args):
"""
    Write the sequence of args, separated by spaces, to the console and output files (if you've configured an output file).
"""
get_current().log(*args)
def warn(*args):
get_current().warn(*args)
def get_dir():
"""
Get directory that log files are being written to.
    Will be None if there is no output directory (i.e., if you didn't call configure)
"""
return get_current().get_dir()
@contextmanager
def profile_kv(scopename, sync_cuda=False):
    """
    Context manager that accumulates the wall-clock time spent in its body
    under the key "wait_<scopename>". Pass sync_cuda=True to synchronize
    CUDA before and after timing so that pending GPU work is included.
    """
    if sync_cuda:
        _sync_cuda()
    logkey = "wait_" + scopename
    tstart = time.time()
try:
yield
finally:
if sync_cuda:
_sync_cuda()
get_current().name2val[logkey] += time.time() - tstart
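# Example usage (a sketch; assumes the logger is configured and run_rollout
# is a hypothetical workload):
#     with profile_kv("rollout"):
#         run_rollout()
# Each subsequent dumpkvs() then reports the accumulated seconds under the
# key "wait_rollout".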
def _sync_cuda():
from torch import cuda
cuda.synchronize()
def profile(n):
    """
    Usage:
    @profile("my_func")
    def my_func(): code

    Also works as a bare decorator, using the function's own name:
    @profile
    def my_func(): code
    """
def decorator_with_name(func, name):
@wraps(func)
def func_wrapper(*args, **kwargs):
with profile_kv(name):
return func(*args, **kwargs)
return func_wrapper
if callable(n):
return decorator_with_name(n, n.__name__)
elif isinstance(n, str):
return partial(decorator_with_name, name=n)
else:
raise NotImplementedError(
"profile should be called as either a bare decorator"
" or with a string (profiling name of a function) as an argument"
)
def dump_kwargs(func):
"""
    Logs all keyword-only parameters of a function call. Useful for printing the hyperparameters used.
Usage:
@logger.dump_kwargs
def create_policy(*, hp1, hp2, hp3): ...
or
logger.dump_kwargs(ppo.learn)(lr=60e-5, ...)
"""
def func_wrapper(*args, **kwargs):
import inspect, textwrap
sign = inspect.signature(func)
for k, p in sign.parameters.items():
if p.kind == inspect.Parameter.KEYWORD_ONLY:
default = "%15s (default)" % str(sign.parameters[k].default)
get_current().log(
"%s.%s: %15s = %s"
% (
func.__module__,
func.__qualname__,
k,
textwrap.shorten(
str(kwargs.get(k, default)),
width=70,
drop_whitespace=False,
placeholder="...",
),
)
)
return func(*args, **kwargs)
return func_wrapper
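# Illustrative output (hypothetical module, function, and values; spacing
# approximate, since each key is right-justified to 15 characters):
#     mypkg.ppo.learn:              lr = 0.0006
#     mypkg.ppo.learn:           gamma =           0.999 (default)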
# ================================================================
# Backend
# ================================================================
# Pytorch explainer:
# If you keep a reference to a variable that depends on parameters, you
# keep around the whole computation graph. That causes an unpleasant surprise
# if you were just trying to log a scalar. We could cast to float, but
# that would require a synchronization, and it would be nice if logging
# didn't require the value to be available immediately. Therefore we
# detach the value at the point of logging, and only cast to float when
# dumping to the log file.
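# For example (a sketch; loss is a hypothetical torch scalar that requires
# grad):
#     logkv_mean("loss", loss)  # stores loss.detach(): no graph retained,
#                               # and no host sync until values are dumped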
def get_current():
if not is_configured():
raise Exception("you must call logger.configure() before using logger")
return Logger.CURRENT
class Logger(object):
CURRENT = None # Current logger being used by the free functions above
def __init__(self, dir, output_formats, comm=None):
self.name2val = defaultdict(float) # values this iteration
self.name2cnt = defaultdict(int)
self.dir = dir
self.output_formats = output_formats
self.comm = comm
# Logging API, forwarded
# ----------------------------------------
def logkv(self, key, val):
if hasattr(val, "requires_grad"): # see "pytorch explainer" above
val = val.detach()
self.name2val[key] = val
def logkv_mean(self, key, val):
assert hasattr(val, "__float__")
if hasattr(val, "requires_grad"): # see "pytorch explainer" above
val = val.detach()
oldval, cnt = self.name2val[key], self.name2cnt[key]
self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
self.name2cnt[key] = cnt + 1
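    # The update above is the standard incremental mean: with cnt previous
    # samples of mean oldval, adding val gives
    # (oldval * cnt + val) / (cnt + 1) == oldval * cnt / (cnt + 1) + val / (cnt + 1).
    # E.g. logkv_mean("b", -22.5) then logkv_mean("b", -44.4) stores
    # (-22.5 - 44.4) / 2 == -33.45 (see _demo below).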
def dumpkvs(self):
if self.comm is None:
d = self.name2val
else:
d = mpi_weighted_mean(
self.comm,
{
name: (val, self.name2cnt.get(name, 1))
for (name, val) in self.name2val.items()
},
)
            if self.comm.rank != 0:
                d["dummy"] = 1  # so we don't get a warning about empty dict
        out = d.copy()  # Return the dict for unit testing purposes
        # Only rank 0 writes; guard against comm being None (single-process use)
        if self.comm is None or self.comm.rank == 0:
            for fmt in self.output_formats:
                fmt.writekvs(d)
self.name2val.clear()
self.name2cnt.clear()
return out
def log(self, *args):
self._do_log(args)
def warn(self, *args):
self._do_log(("[WARNING]", *args))
# Configuration
# ----------------------------------------
def get_dir(self):
return self.dir
def close(self):
for fmt in self.output_formats:
fmt.close()
# Misc
# ----------------------------------------
def _do_log(self, args):
for fmt in self.output_formats:
if isinstance(fmt, SeqWriter):
fmt.writeseq(map(str, args))
def configure(
dir: "(str|None) Local directory to write to" = None,
format_strs: "(str|None) list of formats" = None,
comm: "(MPI communicator | None) average numerical stats over comm" = None,
):
if dir is None:
if os.getenv("OPENAI_LOGDIR"):
dir = os.environ["OPENAI_LOGDIR"]
else:
dir = osp.join(
tempfile.gettempdir(),
datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
)
os.makedirs(dir, exist_ok=True)
# choose log suffix based on world rank because otherwise the files will collide
# if we split the world comm into different comms
if MPI.COMM_WORLD.rank == 0:
log_suffix = ""
else:
log_suffix = "-rank%03i" % MPI.COMM_WORLD.rank
if comm is None:
comm = MPI.COMM_WORLD
format_strs = format_strs or default_format_strs(comm.rank)
output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
log("logger: logging to %s" % dir)
def is_configured():
return Logger.CURRENT is not None
def default_format_strs(rank):
if rank == 0:
return ["stdout", "log", "csv"]
else:
return []
@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
prevlogger = Logger.CURRENT
configure(dir=dir, format_strs=format_strs, comm=comm)
try:
yield
finally:
Logger.CURRENT.close()
Logger.CURRENT = prevlogger
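# Example usage (a sketch; the directory is illustrative):
#     with scoped_configure(dir="/tmp/eval-logs", format_strs=["csv"]):
#         logkv("eval_return", 1.0)
#         dumpkvs()
#     # ... the previous logger (if any) is restored here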
# ================================================================
def _demo():
configure()
log("hi")
dir = "/tmp/testlogging"
if os.path.exists(dir):
shutil.rmtree(dir)
configure(dir=dir)
logkv("a", 3)
logkv("b", 2.5)
dumpkvs()
logkv("b", -2.5)
logkv("a", 5.5)
dumpkvs()
log("^^^ should see a = 5.5")
logkv_mean("b", -22.5)
logkv_mean("b", -44.4)
logkv("a", 5.5)
dumpkvs()
log("^^^ should see b = -33.3")
logkv("b", -2.5)
dumpkvs()
# ================================================================
# Readers
# ================================================================
def read_json(fname):
import pandas
ds = []
with open(fname, "rt") as fh:
for line in fh:
ds.append(json.loads(line))
return pandas.DataFrame(ds)
def read_csv(fname):
import pandas
return pandas.read_csv(fname, index_col=None, comment="#")
def read_tb(path):
"""
path : a tensorboard file OR a directory, where we will find all TB files
of the form events.*
"""
import pandas
import numpy as np
from glob import glob
import tensorflow as tf
if osp.isdir(path):
fnames = glob(osp.join(path, "events.*"))
elif osp.basename(path).startswith("events."):
fnames = [path]
else:
raise NotImplementedError(
"Expected tensorboard file or directory containing them. Got %s" % path
)
tag2pairs = defaultdict(list)
maxstep = 0
for fname in fnames:
for summary in tf.train.summary_iterator(fname):
if summary.step > 0:
for v in summary.summary.value:
pair = (summary.step, v.simple_value)
tag2pairs[v.tag].append(pair)
maxstep = max(summary.step, maxstep)
data = np.empty((maxstep, len(tag2pairs)))
data[:] = np.nan
tags = sorted(tag2pairs.keys())
for (colidx, tag) in enumerate(tags):
pairs = tag2pairs[tag]
for (step, value) in pairs:
data[step - 1, colidx] = value
return pandas.DataFrame(data, columns=tags)
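# Example usage (a sketch; the path and tag are illustrative):
#     df = read_tb("/tmp/testlogging/tb")
#     print(df["eprewmean"].dropna())
# Row i holds the values logged at step i + 1; tags missing at a step are NaN.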
if __name__ == "__main__":
_demo()