pyro/contrib/funsor/handlers/plate_messenger.py (111 lines of code) (raw):

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from numbers import Number

import funsor

from pyro.distributions.util import copy_docs_from
from pyro.poutine.broadcast_messenger import BroadcastMessenger
from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.poutine.messenger import Messenger
from pyro.poutine.subsample_messenger import SubsampleMessenger as OrigSubsampleMessenger
from pyro.util import ignore_jit_warnings

from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor
from pyro.contrib.funsor.handlers.named_messenger import DimRequest, DimType, GlobalNamedMessenger

funsor.set_backend("torch")


def _push_frame(msg, frame):
    """Put ``frame`` at the front of the message's ``cond_indep_stack`` tuple."""
    msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]


class IndepMessenger(GlobalNamedMessenger):
    """
    Vectorized plate whose batch dimension is obtained through ``to_data``
    rather than through ``_DIM_ALLOCATOR``.
    """

    def __init__(self, name=None, size=None, dim=None, indices=None):
        """
        :param name: plate name; when omitted a fresh one is generated with
            ``funsor.interpreter.gensym("PLATE")``.
        :param size: number of elements in the plate (asserted to be > 1).
        :param dim: ``None`` or a negative batch dimension (asserted).
        :param indices: optional index values; defaults to an arange of
            length ``size``. Its length is asserted to equal ``size``.
        """
        assert size > 1
        assert dim is None or dim < 0
        super().__init__()

        # A plate given neither a name nor a dim acts as a "vectorize" effect
        # and gets a non-visible (global) dimension.
        is_anonymous = name is None and dim is None
        self.dim_type = DimType.GLOBAL if is_anonymous else DimType.VISIBLE

        if name is None:
            name = funsor.interpreter.gensym("PLATE")
        self.name = name
        self.size = size
        self.dim = dim

        # Subclasses may have set ``_full_size`` before delegating here
        # (SubsampleMessenger does); only fill it in when it is missing.
        if not hasattr(self, "_full_size"):
            self._full_size = size

        if indices is None:
            prototype = funsor.tensor.get_default_prototype()
            indices = funsor.ops.new_arange(prototype, self.size)
        assert len(indices) == size

        # ``_full_size`` is passed as the third positional argument of
        # funsor.Tensor (presumably the bounded-integer dtype -- confirm
        # against the funsor API).
        named_inputs = OrderedDict([(self.name, funsor.Bint[self.size])])
        self._indices = funsor.Tensor(indices, named_inputs, self._full_size)

    def __enter__(self):
        """Enter the handler, then materialize the indices and record the dim."""
        super().__enter__()  # do this first to take care of globals recycling
        request = DimRequest(self.dim, self.dim_type)
        raw_indices = to_data(
            self._indices, name_to_dim=OrderedDict([(self.name, request)])
        )
        # Read the dimension back from what to_data produced, to match
        # plate's current behavior.
        allocated_dim = -raw_indices.dim()
        flat_indices = raw_indices.squeeze()
        self.dim = allocated_dim
        self.indices = flat_indices
        return self

    def _pyro_sample(self, msg):
        """Record this plate's frame on a sample site."""
        _push_frame(msg, CondIndepStackFrame(self.name, self.dim, self.size, 0))

    def _pyro_param(self, msg):
        """Record this plate's frame on a param site."""
        _push_frame(msg, CondIndepStackFrame(self.name, self.dim, self.size, 0))


# Subsampling plate: draws its size/indices through the original
# SubsampleMessenger._subsample and rescales sites by size / subsample_size.
@copy_docs_from(OrigSubsampleMessenger)
class SubsampleMessenger(IndepMessenger):

    def __init__(self, name=None, size=None, subsample_size=None, subsample=None, dim=None,
                 use_cuda=None, device=None):
        full_size, batch_size, batch_indices = OrigSubsampleMessenger._subsample(
            name, size, subsample_size, subsample, use_cuda, device)
        self.subsample_size = batch_size
        self._full_size = full_size
        self._scale = float(full_size) / batch_size
        # The parent initializer runs last so that it sees ``_full_size``
        # already set and builds the plate over the subsample only.
        super().__init__(name, batch_size, dim, batch_indices)

    def _pyro_sample(self, msg):
        """Add the plate frame, then multiply the site's scale by ``_scale``."""
        super()._pyro_sample(msg)
        msg["scale"] = msg["scale"] * self._scale

    def _pyro_param(self, msg):
        """Add the plate frame, then multiply the site's scale by ``_scale``."""
        super()._pyro_param(msg)
        msg["scale"] = msg["scale"] * self._scale

    def _subsample_site_value(self, value, event_dim=None):
        """
        Return ``value`` indexed by this plate's indices when it depends on
        the plate, otherwise return ``value`` itself (the same object).
        """
        # Subsample only when a dim is known, an event_dim was supplied and
        # the plate is actually smaller than the full size.
        if self.dim is None or event_dim is None:
            return value
        if not self.subsample_size < self._full_size:
            return value

        n_dims = len(value.shape)
        event_shape = value.shape[n_dims - event_dim:]
        funsor_value = to_funsor(value, output=funsor.Reals[event_shape])
        if self.name not in funsor_value.inputs:
            return value
        return to_data(funsor_value(**{self.name: self._indices}))

    def _pyro_post_param(self, msg):
        """
        Replace the param site's value with its subsampled version and keep a
        link from the new value back to the unconstrained parameter.
        """
        event_dim = msg["kwargs"].get("event_dim")
        original = msg["value"]
        subsampled = self._subsample_site_value(original, event_dim)
        if subsampled is original:
            return  # nothing was subsampled; leave the site untouched

        if hasattr(original, "_pyro_unconstrained_param"):
            param = original._pyro_unconstrained_param
        else:
            param = original.unconstrained()

        if not hasattr(param, "_pyro_subsample"):
            param._pyro_subsample = {}

        # TODO is this going to persist correctly?
        # The indices are stored under the plate dim shifted left by event_dim.
        param._pyro_subsample[self.dim - event_dim] = self.indices
        subsampled._pyro_unconstrained_param = param
        msg["value"] = subsampled

    def _pyro_post_subsample(self, msg):
        """Replace the subsample site's value with its subsampled version."""
        event_dim = msg["kwargs"].get("event_dim")
        msg["value"] = self._subsample_site_value(msg["value"], event_dim)


class PlateMessenger(SubsampleMessenger):
    """
    New IndepMessenger implementation combined with the existing
    BroadcastMessenger. Should eventually be a drop-in replacement for
    pyro.plate.
    """

    def __enter__(self):
        """Enter the plate and return its indices."""
        super().__enter__()
        return self.indices  # match pyro.plate behavior

    def _pyro_sample(self, msg):
        """Apply the subsample handling, then BroadcastMessenger's sample hook."""
        super()._pyro_sample(msg)
        BroadcastMessenger._pyro_sample(msg)

    def __iter__(self):
        """Iterate over the plate sequentially via ``_SequentialPlateMessenger``."""
        flat_indices = self._indices.data.squeeze()
        sequential = _SequentialPlateMessenger(
            self.name, self.size, flat_indices, self._scale
        )
        return iter(sequential)


class _SequentialPlateMessenger(Messenger):
    """
    Sequential plate implementation. Not meant to be used directly.
    """

    def __init__(self, name, size, indices, scale):
        self.name = name
        self.size = size
        self.indices = indices
        self._scale = scale
        # Value placed in the counter slot of each frame this handler emits.
        self._counter = 0
        super().__init__()

    def __iter__(self):
        """Yield each index as a plain number while this handler is active."""
        jit_filters = [("Iterating over a tensor", RuntimeWarning)]
        with ignore_jit_warnings(jit_filters), self:
            self._counter = 0
            for index in self.indices:
                self._counter += 1
                if isinstance(index, Number):
                    yield index
                else:
                    yield index.item()

    def _pyro_sample(self, msg):
        """Add a dim-less frame carrying the current counter and rescale the site."""
        _push_frame(msg, CondIndepStackFrame(self.name, None, self.size, self._counter))
        msg["scale"] = msg["scale"] * self._scale

    def _pyro_param(self, msg):
        """Add a dim-less frame carrying the current counter and rescale the site."""
        _push_frame(msg, CondIndepStackFrame(self.name, None, self.size, self._counter))
        msg["scale"] = msg["scale"] * self._scale