examples/rsa/search_inference.py

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

"""
Inference algorithms and utilities used in the RSA example models.

Adapted from: http://dippl.org/chapters/03-enumeration.html
"""

import collections
import functools
import queue

import torch

import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.abstract_infer import TracePosterior
from pyro.poutine.runtime import NonlocalExit


def memoize(fn=None, **kwargs):
    if fn is None:
        return lambda _fn: memoize(_fn, **kwargs)
    return functools.lru_cache(**kwargs)(fn)


class HashingMarginal(dist.Distribution):
    """
    :param trace_dist: a TracePosterior instance representing a Monte Carlo posterior

    Marginal histogram distribution.
    Turns a TracePosterior object into a Distribution
    over the return values of the TracePosterior's model.
    """

    def __init__(self, trace_dist, sites=None):
        assert isinstance(trace_dist, TracePosterior), \
            "trace_dist must be trace posterior distribution object"

        if sites is None:
            sites = "_RETURN"

        assert isinstance(sites, (str, list)), \
            "sites must be either '_RETURN' or a list of site names"

        self.sites = sites
        super().__init__()
        self.trace_dist = trace_dist

    has_enumerate_support = True

    @memoize(maxsize=10)
    def _dist_and_values(self):
        # XXX currently this whole object is very inefficient
        values_map, logits = collections.OrderedDict(), collections.OrderedDict()
        for tr, logit in zip(self.trace_dist.exec_traces,
                             self.trace_dist.log_weights):
            if isinstance(self.sites, str):
                value = tr.nodes[self.sites]["value"]
            else:
                value = {site: tr.nodes[site]["value"] for site in self.sites}
            if not torch.is_tensor(logit):
                logit = torch.tensor(logit)

            if torch.is_tensor(value):
                value_hash = hash(value.cpu().contiguous().numpy().tobytes())
            elif isinstance(value, dict):
                value_hash = hash(self._dict_to_tuple(value))
            else:
                value_hash = hash(value)
            if value_hash in logits:
                # Value has already been seen; accumulate its probability mass.
                logits[value_hash] = dist.util.logsumexp(torch.stack([logits[value_hash], logit]), dim=-1)
            else:
                logits[value_hash] = logit
                values_map[value_hash] = value

        logits = torch.stack(list(logits.values())).contiguous().view(-1)
        logits = logits - dist.util.logsumexp(logits, dim=-1)
        d = dist.Categorical(logits=logits)
        return d, values_map

    def sample(self):
        d, values_map = self._dist_and_values()
        ix = d.sample()
        return list(values_map.values())[ix]

    def log_prob(self, val):
        d, values_map = self._dist_and_values()
        if torch.is_tensor(val):
            value_hash = hash(val.cpu().contiguous().numpy().tobytes())
        elif isinstance(val, dict):
            value_hash = hash(self._dict_to_tuple(val))
        else:
            value_hash = hash(val)
        return d.log_prob(torch.tensor([list(values_map.keys()).index(value_hash)]))

    def enumerate_support(self):
        d, values_map = self._dist_and_values()
        return list(values_map.values())[:]

    def _dict_to_tuple(self, d):
        """
        Recursively converts a dictionary to a tuple of key-value tuples.
        Only intended for use as a helper function inside HashingMarginal!
        May break when keys can't be sorted, but that is not an expected use-case.
        """
        if isinstance(d, dict):
            return tuple([(k, self._dict_to_tuple(d[k])) for k in sorted(d.keys())])
        else:
            return d

    def _weighted_mean(self, value, dim=0):
        # The weights are the deduplicated, normalized logits computed in
        # _dist_and_values, so they line up with samples stacked from values_map.
        weights = self._dist_and_values()[0].logits.reshape([-1] + (value.dim() - 1) * [1])
        max_weight = weights.max(dim=dim)[0]
        relative_probs = (weights - max_weight).exp()
        return (value * relative_probs).sum(dim=dim) / relative_probs.sum(dim=dim)

    @property
    def mean(self):
        samples = torch.stack(list(self._dist_and_values()[1].values()))
        return self._weighted_mean(samples)

    @property
    def variance(self):
        samples = torch.stack(list(self._dist_and_values()[1].values()))
        deviation_squared = torch.pow(samples - self.mean, 2)
        return self._weighted_mean(deviation_squared)
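
# Usage sketch (illustrative only): wrap a TracePosterior such as the Search
# class defined below, then query the marginal over the model's return value.
# Here `model` stands for any no-argument stochastic function with discrete
# sample sites; it is a placeholder, not a name from this file.
#
#     posterior = Search(model).run()
#     marginal = HashingMarginal(posterior)
#     support = marginal.enumerate_support()
#     log_probs = [marginal.log_prob(v) for v in support]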

########################
# Exact Search inference
########################

class Search(TracePosterior):
    """
    Exact inference by enumerating over all possible executions
    """

    def __init__(self, model, max_tries=int(1e6), **kwargs):
        self.model = model
        self.max_tries = max_tries
        super().__init__(**kwargs)

    def _traces(self, *args, **kwargs):
        q = queue.Queue()
        q.put(poutine.Trace())
        p = poutine.trace(
            poutine.queue(self.model, queue=q, max_tries=self.max_tries))
        while not q.empty():
            tr = p.get_trace(*args, **kwargs)
            yield tr, tr.log_prob_sum()


###############################################
# Best-first Search Inference
###############################################

def pqueue(fn, queue):

    def sample_escape(tr, site):
        return (site["name"] not in tr) and \
            (site["type"] == "sample") and \
            (not site["is_observed"])

    def _fn(*args, **kwargs):

        for i in range(int(1e6)):
            assert not queue.empty(), \
                "trying to get() from an empty queue will deadlock"

            priority, next_trace = queue.get()
            try:
                ftr = poutine.trace(poutine.escape(poutine.replay(fn, next_trace),
                                                   functools.partial(sample_escape,
                                                                     next_trace)))
                return ftr(*args, **kwargs)
            except NonlocalExit as site_container:
                site_container.reset_stack()
                for tr in poutine.util.enum_extend(ftr.trace.copy(),
                                                   site_container.site):
                    # add a little bit of noise to the priority to break ties...
                    queue.put((tr.log_prob_sum().item() - torch.rand(1).item() * 1e-2, tr))

        raise ValueError("max tries ({}) exceeded".format(str(1e6)))

    return _fn


class BestFirstSearch(TracePosterior):
    """
    Inference by enumerating executions ordered by their probabilities.
    Exact (and results equivalent to Search) if all executions are enumerated.
    """

    def __init__(self, model, num_samples=None, **kwargs):
        if num_samples is None:
            num_samples = 100
        self.num_samples = num_samples
        self.model = model
        super().__init__(**kwargs)

    def _traces(self, *args, **kwargs):
        q = queue.PriorityQueue()
        # add a little bit of noise to the priority to break ties...
        q.put((torch.zeros(1).item() - torch.rand(1).item() * 1e-2, poutine.Trace()))
        q_fn = pqueue(self.model, queue=q)
        for i in range(self.num_samples):
            if q.empty():
                # num_samples was too large!
                break
            tr = poutine.trace(q_fn).get_trace(*args, **kwargs)  # XXX should block
            yield tr, tr.log_prob_sum()
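

if __name__ == "__main__":
    # Minimal smoke test, a sketch for illustration: enumerate a two-coin model
    # exactly with Search, then read the marginal over its return value off a
    # HashingMarginal. The model `coins` is hypothetical, not one of the RSA
    # example models.
    import pyro

    def coins():
        x = pyro.sample("x", dist.Bernoulli(0.5))
        y = pyro.sample("y", dist.Bernoulli(0.5))
        return x + y  # number of heads: 0., 1., or 2.

    marginal = HashingMarginal(Search(coins).run())
    for value in marginal.enumerate_support():
        # Expected: p(0) = 0.25, p(1) = 0.5, p(2) = 0.25
        print(float(value), float(marginal.log_prob(value).exp()))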