in pyro/nn/module.py [0:0]
def __getattr__(self, name: str):
    """Resolve ``name`` with Pyro side effects for params, samples and submodules.

    Python calls ``__getattr__`` only after normal attribute lookup fails.
    This override handles three cases, in order:

    1. ``name`` is a declared PyroParam (a key of ``self._pyro_params``).
       The unconstrained value stored on the module as
       ``<name>_unconstrained`` is synced with the global param store, and
       the constrained value is returned (via ``pyro.param`` or
       ``transform_to``).
    2. ``name`` is a declared PyroSample (a key of ``self._pyro_samples``).
       A value is drawn from the prior. When the module's context is
       active, the draw goes through ``pyro.sample`` and is cached in that
       context.
    3. Anything else is delegated to ``super().__getattr__``. If the result
       is a ``torch.nn.Parameter`` or a non-Pyro ``torch.nn.Module``, it is
       also registered via ``pyro.param`` / ``pyro.module`` when the context
       is active.

    Args:
        name: the attribute being looked up.

    Returns:
        One of three values: the constrained PyroParam value, the sampled
        PyroSample value, or the parent-class lookup result.

    Raises:
        Whatever ``super().__getattr__`` raises for unknown names. This is
        presumably ``AttributeError`` from ``torch.nn.Module``; TODO confirm.
    """
    # PyroParams trigger pyro.param statements.
    # Read self.__dict__ directly rather than self._pyro_params. The latter
    # could itself re-enter __getattr__ when the dict is absent (presumably
    # e.g. before __init__ has set it up).
    if '_pyro_params' in self.__dict__:
        _pyro_params = self.__dict__['_pyro_params']
        if name in _pyro_params:
            # Each PyroParam entry is a (constraint, event_dim) pair.
            constraint, event_dim = _pyro_params[name]
            # The unconstrained storage lives under a suffixed attribute name.
            # That name is excluded from the nn.Parameter branch below, so
            # this lookup does not register it a second time.
            unconstrained_value = getattr(self, name + "_unconstrained")
            if self._pyro_context.active:
                fullname = self._pyro_get_fullname(name)
                if fullname in _PYRO_PARAM_STORE:
                    # The store already has an entry under this name, but it
                    # holds a different object than the module's copy.
                    if _PYRO_PARAM_STORE._params[fullname] is not unconstrained_value:
                        # Update PyroModule <--- ParamStore.
                        unconstrained_value = _PYRO_PARAM_STORE._params[fullname]
                        if not isinstance(unconstrained_value, torch.nn.Parameter):
                            # Update PyroModule ---> ParamStore (type only; data is preserved).
                            # Wrap the store's value as an nn.Parameter and put the
                            # wrapped object back into the store. This keeps store
                            # and module sharing one object.
                            unconstrained_value = torch.nn.Parameter(unconstrained_value)
                            _PYRO_PARAM_STORE._params[fullname] = unconstrained_value
                            _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
                        # Install the adopted value on the module. super() is used,
                        # presumably to bypass this class's own __setattr__
                        # handling. TODO confirm.
                        super().__setattr__(name + "_unconstrained", unconstrained_value)
                else:
                    # Update PyroModule ---> ParamStore.
                    # First time this full name is seen: register the module's
                    # unconstrained value together with its constraint.
                    _PYRO_PARAM_STORE._constraints[fullname] = constraint
                    _PYRO_PARAM_STORE._params[fullname] = unconstrained_value
                    _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
                # Store and module now agree. pyro.param records the access and
                # returns the value, presumably constrained via the registered
                # constraint.
                return pyro.param(fullname, event_dim=event_dim)
            else:  # Cannot determine supermodule and hence cannot compute fullname.
                # No param-store interaction; just apply the constraint transform.
                return transform_to(constraint)(unconstrained_value)
    # PyroSample trigger pyro.sample statements.
    # As above, read through self.__dict__ to avoid re-entering __getattr__.
    if '_pyro_samples' in self.__dict__:
        _pyro_samples = self.__dict__['_pyro_samples']
        if name in _pyro_samples:
            prior = _pyro_samples[name]
            context = self._pyro_context
            if context.active:
                fullname = self._pyro_get_fullname(name)
                # Reuse a value already drawn under this name in the current
                # context, so repeated attribute access yields the same sample.
                value = context.get(fullname)
                if value is None:
                    # A prior is either a distribution (has .sample) or a
                    # callable that builds one from this module.
                    if not hasattr(prior, "sample"):  # if not a distribution
                        prior = prior(self)
                    value = pyro.sample(fullname, prior)
                    context.set(fullname, value)
                return value
            else:  # Cannot determine supermodule and hence cannot compute fullname.
                if not hasattr(prior, "sample"):  # if not a distribution
                    prior = prior(self)
                # Call the distribution directly, outside pyro.sample. The draw
                # is not cached, so each access presumably yields a fresh sample.
                return prior()
    # Not a PyroParam or PyroSample. Presumably this is torch.nn.Module's
    # lookup over registered parameters, buffers and submodules.
    result = super().__getattr__(name)
    # Regular nn.Parameters trigger pyro.param statements.
    # "*_unconstrained" attributes are the backing storage of PyroParams,
    # which the first branch above already registers under the public name.
    if isinstance(result, torch.nn.Parameter) and not name.endswith("_unconstrained"):
        if self._pyro_context.active:
            pyro.param(self._pyro_get_fullname(name), result)
    # Regular nn.Modules trigger pyro.module statements.
    # PyroModule submodules are skipped, presumably because their own
    # __getattr__ registers their parameters individually.
    if isinstance(result, torch.nn.Module) and not isinstance(result, PyroModule):
        if self._pyro_context.active:
            pyro.module(self._pyro_get_fullname(name), result)
    return result