pyro/infer/mcmc/logger.py (125 lines of code) (raw):
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import copy
import json
import logging
import os
import sys
from collections import OrderedDict

from tqdm import tqdm
from tqdm.auto import tqdm as tqdm_nb
try:
    # ``get_ipython`` only exists as a global/builtin when running under
    # IPython (terminal or Jupyter); merely referencing it is the probe.
    get_ipython  # noqa: F821
except NameError:
    ipython_env = False
else:
    ipython_env = True
# Tags stored on a log record's ``msg_type`` attribute so that progress-bar
# diagnostics can be told apart from ordinary logging output -- this matters
# when records from several processes are funnelled through a QueueHandler.
LOG_MSG = "LOG"  # regular logging output
DIAGNOSTIC_MSG = "DIAGNOSTICS"  # JSON payload shown as the progress-bar postfix
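# The QueueHandler class further below is vendored from CPython for Python 2
# compatibility (the standard library ships it as logging.handlers.QueueHandler
# since Python 3.2). Its original copyright notice follows.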
# Source: https://github.com/python/cpython/blob/master/Lib/logging/handlers.py
#
# Copyright 2001-2016 by Vinay Sajip. All Rights Reserved.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose and without fee is hereby granted,
# provided that the above copyright notice appear in all copies and that
# both that copyright notice and this permission notice appear in
# supporting documentation, and that the name of Vinay Sajip
# not be used in advertising or publicity pertaining to distribution
# of the software without specific, written prior permission.
# VINAY SAJIP DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING
# ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# VINAY SAJIP BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR
# ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER
# IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
class ProgressBar:
    """
    One or more :class:`~tqdm.tqdm` progress bars spanning warmup + sampling.

    :param int warmup_steps: Number of warmup steps.
    :param int num_samples: Number of MCMC samples.
    :param int min_width: Minimum column width of the bar.
    :param int max_width: Maximum column width of the bar.
    :param bool disable: Disable progress bar. Bars are also disabled when the
        ``CI`` or ``PYTEST_XDIST_WORKER`` environment variable is set.
    :param int num_bars: Number of progress bars to initialize.
        If multiple bars are initialized, they need to be separately
        updated via the ``pos`` kwarg.
    """

    def __init__(self, warmup_steps, num_samples, min_width=80, max_width=120,
                 disable=False, num_bars=1):
        # Keep CI logs clean
        # (see https://github.com/travis-ci/travis-ci/issues/1337).
        running_in_ci = "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ
        disable = disable or running_in_ci

        # Several bars inside a notebook are rendered with tqdm.auto's widget.
        use_notebook_bars = ipython_env and num_bars > 1
        bar_factory = tqdm_nb if use_notebook_bars else tqdm
        total_steps = warmup_steps + num_samples
        layout = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}{postfix}]"

        def build(position):
            label = "Warmup" if num_bars == 1 else "Warmup [{}]".format(position + 1)
            bar = bar_factory(total=total_steps, desc=label, bar_format=layout,
                              position=position, file=sys.stderr, disable=disable)
            # Only clamp when tqdm knows the terminal width; otherwise leave
            # tqdm's own fallback alone.
            width = getattr(bar, "ncols", None)
            if width is not None:
                bar.ncols = min(max_width, max(min_width, width))
            return bar

        self.progress_bars = [build(position) for position in range(num_bars)]
        self.disable = disable
        self.ipython_env = ipython_env

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.close()
        # Never swallow an exception raised inside the ``with`` body.
        return False

    def _dispatch(self, method, args, kwargs):
        # ``pos`` picks the target bar and must not be forwarded to tqdm.
        index = kwargs.pop("pos", 0)
        if self.disable:
            return
        getattr(self.progress_bars[index], method)(*args, **kwargs)

    def set_description(self, *args, **kwargs):
        self._dispatch("set_description", args, kwargs)

    def set_postfix(self, *args, **kwargs):
        self._dispatch("set_postfix", args, kwargs)

    def update(self, *args, **kwargs):
        self._dispatch("update", args, kwargs)

    def close(self):
        for bar in self.progress_bars:
            bar.close()
        # Terminal bars are stacked by ``position``; emit one newline per bar
        # so subsequent output does not overwrite them.
        if self.disable or self.ipython_env:
            return
        sys.stderr.write("\n" * len(self.progress_bars))
class QueueHandler(logging.Handler):
    """
    This handler sends events to a queue. Typically, it would be used together
    with a multiprocessing Queue to centralise logging to file in one process
    (in a multi-process application), so as to avoid file write contention
    between processes.

    This code is new in Python 3.2, but this class can be copy pasted into
    user code for use with earlier Python versions.
    """

    def __init__(self, queue):
        """
        Initialise an instance, using the passed queue.
        """
        logging.Handler.__init__(self)
        self.queue = queue

    def enqueue(self, record):
        """
        Enqueue a record.

        The base implementation uses put_nowait. You may want to override
        this method if you want to use blocking, timeouts or custom queue
        implementations.
        """
        self.queue.put_nowait(record)

    def prepare(self, record):
        """
        Prepare a record for queuing. The object returned by this method is
        enqueued.

        The record is formatted, then a shallow copy is made whose message
        and arguments are merged into ``msg`` and whose unpickleable /
        already-rendered items are removed. The original ``record`` keeps its
        ``msg``/``args``/``exc_info`` (formatting itself may still cache
        fields on it).

        You might want to override this method if you want to convert
        the record to a dict or JSON string.
        """
        msg = self.format(record)
        # Work on a copy so the caller (and any other handler seeing the same
        # record) is not affected by the changes below.
        record = copy.copy(record)
        record.message = msg
        record.msg = msg
        record.args = None
        record.exc_info = None
        # ``msg`` already contains the rendered traceback / stack. Clear the
        # cached copies too, otherwise a Formatter on the receiving side
        # appends them a second time.
        record.exc_text = None
        record.stack_info = None
        return record

    def emit(self, record):
        """
        Emit a record.

        Writes the LogRecord to the queue, preparing it for pickling first.
        """
        try:
            self.enqueue(self.prepare(record))
        except Exception:
            self.handleError(record)
class TqdmHandler(logging.StreamHandler):
    """
    Handler that synchronizes the log output with the
    :class:`~tqdm.tqdm` progress bar.

    Formatted messages are written to ``sys.stderr`` through
    :meth:`tqdm.tqdm.write` rather than directly to the stream.
    """

    def emit(self, record):
        # KeyboardInterrupt / SystemExit do not derive from Exception in
        # Python 3, so they already propagate without an explicit re-raise.
        try:
            msg = self.format(record)
            self.flush()
            tqdm.write(msg, file=sys.stderr)
        except Exception:
            self.handleError(record)
class MCMCLoggingHandler(logging.Handler):
    """
    Main logging handler used by :class:`~pyro.infer.mcmc`,
    to handle both progress bar updates and regular `logging`
    messages.

    :param log_handler: default log handler for logging
        output.
    :param progress_bar: If provided, diagnostic information
        is updated using the bar. It must expose tqdm-style
        ``set_postfix(..., refresh=...)`` and ``update()`` methods.
    """

    def __init__(self, log_handler, progress_bar=None):
        logging.Handler.__init__(self)
        self.log_handler = log_handler
        self.progress_bar = progress_bar

    def emit(self, record):
        """
        Route ``record`` to the progress bar or to ``log_handler``.

        Records whose ``msg_type`` equals ``DIAGNOSTIC_MSG`` carry a
        JSON-encoded mapping. When a progress bar is configured, that mapping
        becomes the bar's postfix and the bar advances by one step. Every
        other record, including one with no ``msg_type``, is passed on to
        ``log_handler``.
        """
        try:
            is_diagnostic = getattr(record, "msg_type", None) == DIAGNOSTIC_MSG
            # Compare against None instead of relying on truthiness: tqdm bars
            # define __bool__/__len__ and can be falsy (e.g. total == 0).
            if self.progress_bar is not None and is_diagnostic:
                diagnostics = json.loads(record.getMessage(),
                                         object_pairs_hook=OrderedDict)
                self.progress_bar.set_postfix(diagnostics, refresh=False)
                self.progress_bar.update()
            else:
                self.log_handler.handle(record)
        except Exception:
            # KeyboardInterrupt / SystemExit are not Exceptions and propagate.
            self.handleError(record)
class MetadataFilter(logging.Filter):
    """
    Annotating filter: it never rejects a record, it only decorates it.

    Each record gets ``logger_id`` set to the id given at construction.
    Its ``msg_type`` falls back to ``LOG_MSG`` whenever the record does not
    already carry a truthy ``msg_type``.
    """

    def __init__(self, logger_id):
        super().__init__()
        self.logger_id = logger_id

    def filter(self, record):
        record.logger_id = self.logger_id
        record.msg_type = getattr(record, "msg_type", None) or LOG_MSG
        return True
def initialize_logger(logger, logger_id, progress_bar=None, log_queue=None):
    """
    Initialize logger for the :class:`pyro.infer.mcmc` module.

    Existing handlers on ``logger`` are replaced by a single
    :class:`MCMCLoggingHandler`, and propagation to ancestor loggers is
    disabled.

    :param logger: logger instance.
    :param str logger_id: identifier for the log record,
        e.g. chain id in case of multiple samplers.
    :param progress_bar: object exposing tqdm-style ``set_postfix`` /
        ``update`` (e.g. a :class:`tqdm.tqdm` or :class:`ProgressBar`) used to
        display diagnostics. Ignored when ``log_queue`` is given.
    :param log_queue: queue that receives every record via a
        :class:`QueueHandler` (useful with multiprocessing). Takes precedence
        over ``progress_bar``.
    :return: the configured ``logger``.
    :raises ValueError: if neither ``log_queue`` nor ``progress_bar`` is
        given; ``logger`` is left unmodified in that case.
    """
    # Build the new handler before touching ``logger`` so that a failure does
    # not leave it half-reset.
    if log_queue is not None:
        handler = QueueHandler(log_queue)
        # The ``msg_type``/``logger_id`` prefix lets whoever drains the queue
        # tell diagnostics apart from log output, and which chain sent them.
        fmt = "[%(levelname)s %(msg_type)s %(logger_id)s]%(message)s"
        # Diagnostics travel through the queue; no local bar is updated.
        progress_bar = None
    elif progress_bar is not None:
        fmt = "%(levelname).1s \t %(message)s"
        handler = TqdmHandler()
    else:
        raise ValueError("Logger cannot be initialized without a "
                         "valid handler.")
    handler.setFormatter(logging.Formatter(fmt))
    logging_handler = MCMCLoggingHandler(handler, progress_bar)
    logging_handler.addFilter(MetadataFilter(logger_id))

    # Reset handlers so a previously installed `progress_bar` / queue is no
    # longer used.
    logger.handlers = []
    logger.propagate = False
    logger.addHandler(logging_handler)
    return logger