in pyro/nn/module.py [0:0]
def __setattr__(self, name, value):
    """Intercept attribute assignment to register Pyro-aware attributes.

    Dispatches on the type of ``value``:

    * ``PyroModule`` -- registered as a submodule via ``self.add_module``.
    * ``PyroParam`` -- unpacked as ``(constrained_value, constraint,
      event_dim)``. ``(constraint, event_dim)`` is recorded in
      ``self._pyro_params[name]`` and an unconstrained tensor is stored as
      the attribute ``name + "_unconstrained"``. When
      ``self._pyro_context.active`` the value is routed through
      ``pyro.param`` under the module's full name and the global
      ``_PYRO_PARAM_STORE`` is kept in sync; otherwise it is unconstrained
      locally with ``_unconstrain``.
    * ``torch.nn.Parameter`` -- stored as a normal attribute, after first
      being routed through ``pyro.param`` when the Pyro context is active.
    * ``torch.Tensor`` assigned to the name of an existing PyroParam -- the
      stored unconstrained tensor's ``.data`` is overwritten in place with
      ``transform_to(constraint).inv(value)``.
    * ``PyroSample`` -- its ``prior`` is recorded in ``_pyro_samples[name]``.

    For the ``PyroModule``, ``PyroParam``, ``nn.Parameter`` and
    ``PyroSample`` cases, any existing attribute called ``name`` is deleted
    first (an ``AttributeError`` from the delete is ignored). All other
    values, including tensors that are not PyroParams, are passed to the
    parent class's ``__setattr__``.

    :param str name: Name of the attribute being assigned.
    :param value: The value being assigned.
    """
    if isinstance(value, (PyroModule, PyroParam, torch.nn.Parameter, PyroSample)):
        # These kinds replace whatever was stored under ``name`` before
        # (possibly an attribute of a different kind), so clear it first.
        try:
            delattr(self, name)
        except AttributeError:
            pass  # Nothing to overwrite.

    if isinstance(value, PyroModule):
        # Create a new sub PyroModule.
        self.add_module(name, value)
        return

    if isinstance(value, PyroParam):
        # Create a new PyroParam.
        constrained_value, constraint, event_dim = value
        self._pyro_params[name] = constraint, event_dim
        if self._pyro_context.active:
            fullname = self._pyro_get_fullname(name)
            pyro.param(
                fullname,
                constrained_value,
                constraint=constraint,
                event_dim=event_dim,
            )
            # Re-read from the param store so the module shares the store's
            # tensor (NOTE(review): assumes pyro.param keeps a value already
            # registered under ``fullname`` -- confirm).
            constrained_value = pyro.param(fullname)
            unconstrained_value = constrained_value.unconstrained()
            if not isinstance(unconstrained_value, torch.nn.Parameter):
                # Update PyroModule ---> ParamStore (type only; data is preserved).
                unconstrained_value = torch.nn.Parameter(unconstrained_value)
                _PYRO_PARAM_STORE._params[fullname] = unconstrained_value
                _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
        else:
            # Cannot determine supermodule and hence cannot compute fullname.
            unconstrained_value = _unconstrain(constrained_value, constraint)
        super().__setattr__(name + "_unconstrained", unconstrained_value)
        return

    if isinstance(value, torch.nn.Parameter):
        # Create a new nn.Parameter.
        if self._pyro_context.active:
            fullname = self._pyro_get_fullname(name)
            value = pyro.param(fullname, value)
            if not isinstance(value, torch.nn.Parameter):
                # Update PyroModule ---> ParamStore (type only; data is preserved).
                value = torch.nn.Parameter(value)
                _PYRO_PARAM_STORE._params[fullname] = value
                _PYRO_PARAM_STORE._param_to_name[value] = fullname
        super().__setattr__(name, value)
        return

    if isinstance(value, torch.Tensor) and name in self._pyro_params:
        # Update the value of an existing PyroParam in place: map the new
        # constrained value back to unconstrained space and overwrite the
        # stored tensor's data without recording autograd history.
        constraint, _ = self._pyro_params[name]
        unconstrained_value = getattr(self, name + "_unconstrained")
        with torch.no_grad():
            unconstrained_value.data = transform_to(constraint).inv(value.detach())
        return

    if isinstance(value, PyroSample):
        # Create a new PyroSample. The dict is read straight from __dict__,
        # presumably to bypass this class's attribute hooks -- confirm.
        _pyro_samples = self.__dict__['_pyro_samples']
        _pyro_samples[name] = value.prior
        return

    super().__setattr__(name, value)