# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import weakref
from abc import ABCMeta, abstractmethod

import torch

from pyro.ops.einsum import contract
from pyro.ops.einsum.adjoint import SAMPLE_SYMBOL, Backward
from pyro.util import ignore_jit_warnings


class Ring(object, metaclass=ABCMeta):
    """
    Abstract tensor ring class.

    Each tensor ring class has a notion of ``dims`` that can be sum-contracted
    out, and a notion of ``ordinal`` that represents a set of plate dimensions
    that can be broadcast up or product-contracted out.

    Implementations should cache intermediate results to be compatible with
    :func:`~opt_einsum.shared_intermediates`.

    Dims are single characters (one-character strings).
    Ordinals are frozensets of characters.

    :param dict cache: an optional :func:`~opt_einsum.shared_intermediates`
        cache.
    """

def __init__(self, cache=None):
self._cache = {} if cache is None else cache

    def _hash_by_id(self, tensor):
"""
Returns the id of a tensor and saves the tensor so that this id can be
used as a key in the cache without risk of the id being recycled.
"""
result = id(tensor)
assert self._cache.setdefault(('tensor', result), tensor) is tensor
return result
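
    # Note (explanatory comment): ``id()`` values can be recycled once a
    # tensor is garbage collected; caching the tensor itself above keeps it
    # alive, so the returned id stays a valid cache key.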

    @abstractmethod
def sumproduct(self, terms, dims):
"""
Multiply all ``terms`` together, then sum-contract out all ``dims``
from the result.
:param list terms: a list of tensors
:param dims: an iterable of sum dims to contract
"""
raise NotImplementedError

    @abstractmethod
def product(self, term, ordinal):
"""
Product-contract the given ``term`` along any plate dimensions
present in given ``ordinal``.
:param torch.Tensor term: the term to contract
:param frozenset ordinal: an ordinal specifying plates to contract
"""
raise NotImplementedError

    def broadcast(self, term, ordinal):
"""
Broadcast the given ``term`` by expanding along any plate dimensions
present in ``ordinal`` but not ``term``.
:param torch.Tensor term: the term to expand
:param frozenset ordinal: an ordinal specifying plates
"""
dims = term._pyro_dims
missing_dims = ''.join(sorted(set(ordinal) - set(dims)))
if missing_dims:
key = 'broadcast', self._hash_by_id(term), missing_dims
if key in self._cache:
term = self._cache[key]
else:
missing_shape = tuple(self._dim_to_size[dim] for dim in missing_dims)
term = term.expand(missing_shape + term.shape)
dims = missing_dims + dims
self._cache[key] = term
term._pyro_dims = dims
return term
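
    # For example (an illustrative sketch, not part of the API): with
    # ``self._dim_to_size == {'i': 3}``, a term with dims 'a' and shape (2,)
    # broadcast over ``frozenset('i')`` becomes a term with dims 'ia' and
    # shape (3, 2), expanded without copying data.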

    @abstractmethod
def inv(self, term):
"""
Computes the reciprocal of a term, for use in inclusion-exclusion.
:param torch.Tensor term: the term to invert
"""
raise NotImplementedError

    def global_local(self, term, dims, ordinal):
r"""
Computes global and local terms for tensor message passing
using inclusion-exclusion::
term / sum(term, dims) * product(sum(term, dims), ordinal)
\____________________/ \_______________________________/
local part global part
:param torch.Tensor term: the term to contract
:param dims: an iterable of sum dims to contract
:param frozenset ordinal: an ordinal specifying plates to contract
:return: a tuple ``(global_part, local_part)`` as defined above
:rtype: tuple
"""
assert dims, 'dims was empty, use .product() instead'
key = 'global_local', self._hash_by_id(term), frozenset(dims), ordinal
if key in self._cache:
return self._cache[key]
term_sum = self.sumproduct([term], dims)
global_part = self.product(term_sum, ordinal)
with ignore_jit_warnings():
local_part = self.sumproduct([term, self.inv(term_sum)], set())
assert sorted(local_part._pyro_dims) == sorted(term._pyro_dims)
result = global_part, local_part
self._cache[key] = result
return result
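
    # Example (illustrative, with a hypothetical sum dim 'a' and plate dim
    # 'i'): if ``term`` has dims 'ia', then
    # ``global_local(term, 'a', frozenset('i'))`` returns a scalar global
    # part ``product_i(sum_a(term))`` and a local part equal to
    # ``term / sum_a(term)`` broadcast over ``term``'s dims.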


class LinearRing(Ring):
    """
    Ring of sum-product operations in linear space.

    Tensor dimensions are packed; to read the name of a tensor, read the
    ``._pyro_dims`` attribute, which is a string of dimension names aligned
    with the tensor's shape.
    """
    _backend = 'torch'

def __init__(self, cache=None, dim_to_size=None):
super().__init__(cache=cache)
self._dim_to_size = {} if dim_to_size is None else dim_to_size

    def sumproduct(self, terms, dims):
inputs = [term._pyro_dims for term in terms]
output = ''.join(sorted(set(''.join(inputs)) - set(dims)))
equation = ','.join(inputs) + '->' + output
term = contract(equation, *terms, backend=self._backend)
term._pyro_dims = output
return term
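
    # For instance (hypothetical dims): terms with dims 'ab' and 'bc'
    # contracted over dims 'b' build the equation 'ab,bc->ac', an ordinary
    # matrix product in linear space.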

    def product(self, term, ordinal):
dims = term._pyro_dims
for dim in sorted(ordinal, reverse=True):
pos = dims.find(dim)
if pos != -1:
key = 'product', self._hash_by_id(term), dim
if key in self._cache:
term = self._cache[key]
else:
term = term.prod(pos)
dims = dims.replace(dim, '')
self._cache[key] = term
term._pyro_dims = dims
return term

    def inv(self, term):
key = 'inv', self._hash_by_id(term)
if key in self._cache:
return self._cache[key]
result = term.reciprocal()
result = result.clamp(max=torch.finfo(result.dtype).max) # avoid nan due to inf / inf
result._pyro_dims = term._pyro_dims
self._cache[key] = result
return result


class LogRing(Ring):
    """
    Ring of sum-product operations in log space.

    Tensor values are in log units, so ``sum`` is implemented as ``logsumexp``,
    and ``product`` is implemented as ``sum``.

    Tensor dimensions are packed; to read the name of a tensor, read the
    ``._pyro_dims`` attribute, which is a string of dimension names aligned
    with the tensor's shape.
    """
    _backend = 'pyro.ops.einsum.torch_log'

def __init__(self, cache=None, dim_to_size=None):
super().__init__(cache=cache)
self._dim_to_size = {} if dim_to_size is None else dim_to_size

    def sumproduct(self, terms, dims):
inputs = [term._pyro_dims for term in terms]
output = ''.join(sorted(set(''.join(inputs)) - set(dims)))
equation = ','.join(inputs) + '->' + output
term = contract(equation, *terms, backend=self._backend)
term._pyro_dims = output
return term

    def product(self, term, ordinal):
dims = term._pyro_dims
for dim in sorted(ordinal, reverse=True):
pos = dims.find(dim)
if pos != -1:
key = 'product', self._hash_by_id(term), dim
if key in self._cache:
term = self._cache[key]
else:
term = term.sum(pos)
dims = dims.replace(dim, '')
self._cache[key] = term
term._pyro_dims = dims
return term

    def inv(self, term):
key = 'inv', self._hash_by_id(term)
if key in self._cache:
return self._cache[key]
result = -term
result = result.clamp(max=torch.finfo(result.dtype).max) # avoid nan due to inf - inf
result._pyro_dims = term._pyro_dims
self._cache[key] = result
return result
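

# Example usage of LogRing (a hedged sketch; dim names and sizes here are
# illustrative, not part of the library API):
#
#     ring = LogRing(dim_to_size={'i': 3})
#     x = torch.randn(3, 2)                # a factor in log space
#     x._pyro_dims = 'ia'                  # 'i' is a plate dim, 'a' a sum dim
#     y = ring.sumproduct([x], 'a')        # logsumexp over 'a'; y._pyro_dims == 'i'
#     z = ring.product(y, frozenset('i'))  # sum over plate 'i'; z is a scalar

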
class _SampleProductBackward(Backward):
"""
Backward-sample implementation of product.
This is agnostic to sampler implementation, and hence can be used both by
:class:`MapRing` (temperature 0 sampling) and :class:`SampleRing`
(temperature 1 sampling).
"""

    def __init__(self, ring, term, ordinal):
self.ring = ring
self.term = term
self.ordinal = ordinal

    def process(self, message):
        if message is not None:
            sample_dims = message._pyro_sample_dims
            message = self.ring.broadcast(message, self.ordinal)
            # Move the sample dim to position 0, where downstream consumers
            # expect it, permuting the remaining dims accordingly.
            if message._pyro_dims.index(SAMPLE_SYMBOL) != 0:
                dims = SAMPLE_SYMBOL + message._pyro_dims.replace(SAMPLE_SYMBOL, '')
                message = message.permute(tuple(map(message._pyro_dims.find, dims)))
                message._pyro_dims = dims
                assert message.dim() == len(message._pyro_dims)
            message._pyro_sample_dims = sample_dims
            assert message.size(0) == len(message._pyro_sample_dims)
        yield self.term._pyro_backward, message


class MapRing(LogRing):
"""
Ring of forward-maxsum backward-argmax operations.
"""
    _backend = 'pyro.ops.einsum.torch_map'

def product(self, term, ordinal):
result = super().product(term, ordinal)
if hasattr(term, '_pyro_backward'):
result._pyro_backward = _SampleProductBackward(self, term, ordinal)
return result


class SampleRing(LogRing):
"""
Ring of forward-sumproduct backward-sample operations in log space.
"""
    _backend = 'pyro.ops.einsum.torch_sample'

def product(self, term, ordinal):
result = super().product(term, ordinal)
if hasattr(term, '_pyro_backward'):
result._pyro_backward = _SampleProductBackward(self, term, ordinal)
return result
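

# Note: MapRing and SampleRing share _SampleProductBackward and differ only in
# their einsum backend, i.e. backward argmax (temperature 0) versus backward
# sampling (temperature 1), as described in _SampleProductBackward's docstring.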


class _MarginalProductBackward(Backward):
"""
Backward-marginal implementation of product, using inclusion-exclusion.
"""

    def __init__(self, ring, term, ordinal, result):
        self.ring = ring
        self.term = term
        self.ordinal = ordinal
        # The forward result owns this object via ``_pyro_backward``, so hold
        # it weakly here to avoid a reference cycle.
        self.result = weakref.ref(result)

    def process(self, message):
        ring = self.ring
        term = self.term
        result = self.result()  # dereference the weakref to the forward result
        factors = [result]
        if message is not None:
            message._pyro_dims = result._pyro_dims
            factors.append(message)
        if term._pyro_backward.is_leaf:
            product = ring.sumproduct(factors, set())
            message = ring.broadcast(product, self.ordinal)
        else:
            # Inclusion-exclusion: divide out the forward term so that only
            # the marginal contribution flows backward.
            factors.append(ring.inv(term))
            message = ring.sumproduct(factors, set())
        yield term._pyro_backward, message


class MarginalRing(LogRing):
"""
Ring of forward-sumproduct backward-marginal operations in log space.
"""
    _backend = 'pyro.ops.einsum.torch_marginal'

def product(self, term, ordinal):
result = super().product(term, ordinal)
if hasattr(term, '_pyro_backward'):
result._pyro_backward = _MarginalProductBackward(self, term, ordinal, result)
return result


BACKEND_TO_RING = {
'torch': LinearRing,
'pyro.ops.einsum.torch_log': LogRing,
'pyro.ops.einsum.torch_map': MapRing,
'pyro.ops.einsum.torch_sample': SampleRing,
'pyro.ops.einsum.torch_marginal': MarginalRing,
}
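
# Example (a hedged sketch): callers holding an einsum backend name can look
# up the matching ring class, e.g.
#
#     ring_class = BACKEND_TO_RING['pyro.ops.einsum.torch_log']
#     ring = ring_class(cache=None)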