in higher/patch.py [0:0]
def __setattr__(self, name, value):
    """Assign ``value`` to attribute ``name``, routing on what ``name`` is.

    The first matching rule wins:

    1. ``name`` is a key of ``self._parameters``: ``value`` must be a
       ``torch.Tensor``. Unless ``self._being_modified_internally`` is
       truthy, the entry for the old tensor in a copy of
       ``self.root.fast_params`` is replaced by ``value`` and the copy is
       handed to ``self.update_params``. In both cases
       ``self._parameters[name]`` is then overwritten with ``value``.
    2. ``value`` is a ``torch.nn.Module``: any entry called ``name`` is
       removed from ``self.__dict__``, ``self._parameters`` and
       ``self._buffers``, and ``value`` is stored in ``self._modules``.
    3. ``name`` is a key of ``self._modules``: only ``None`` is accepted.
    4. ``name`` is a key of ``self._buffers``: ``value`` must be a
       ``torch.Tensor`` or ``None``.
    5. Otherwise the assignment is delegated to ``object.__setattr__``.

    Args:
        name: Attribute name being assigned.
        value: New value for the attribute.

    Raises:
        TypeError: If a parameter is assigned a non-tensor, an existing
            child-module slot is assigned something other than ``None`` or
            a module, or a buffer is assigned a non-tensor, non-``None``
            value.
        AttributeError: If a module is assigned while ``_modules`` is not
            yet present in ``self.__dict__``.
        Exception: If a parameter is assigned directly but the root module
            exposes no (or empty) fast parameters.
    """
    def remove_from(*dicts):
        # Drop any existing entry called `name` from each given mapping.
        for d in dicts:
            d.pop(name, None)

    # Look the registries up through __dict__ so that a missing registry
    # yields None instead of triggering attribute lookup fallbacks.
    params = self.__dict__.get('_parameters')
    if params is not None and name in params:
        if not isinstance(value, _torch.Tensor):
            raise TypeError("Require Tensor as fast weights. "
                            "Got {}".format(_torch.typename(value)))

        if not self._being_modified_internally:
            # Additional behaviour for when fast weights are being
            # directly modified goes here:
            old_value = self._parameters[name]
            current_fast_params = self.root.fast_params
            # Validate before copying, so that a missing (None) or empty
            # collection produces the explanatory error below rather than
            # an unrelated failure from the slice.
            if not current_fast_params:
                raise Exception(
                    "Cannot assign parameters to patched module which "
                    "does not have implicit fast parameters."
                )
            # Work on a copy; the collection held by the root is only
            # replaced through update_params.
            fast_params = current_fast_params[:]
            # NOTE(review): assumes the helper returns a valid index for
            # old_value -- confirm its behaviour when nothing matches.
            replacement_index = _utils._find_param_in_list(
                old_value, fast_params
            )
            fast_params[replacement_index] = value
            self.update_params(fast_params)

        # Change parameters in place, usually during boxed_forward pass
        self._parameters[name] = value
        return

    modules = self.__dict__.get('_modules')
    if isinstance(value, _torch.nn.Module):
        if modules is None:
            raise AttributeError(
                "cannot assign module before Module.__init__() "
                "call"
            )
        # A name may live in only one registry at a time.
        remove_from(self.__dict__, self._parameters, self._buffers)
        modules[name] = value
        return

    if modules is not None and name in modules:
        # An existing child-module slot may only be cleared here; module
        # values were already handled by the branch above.
        if value is not None:
            raise TypeError(
                (
                    "cannot assign '{}' "
                    "as child module '{}' "
                    "(torch.nn.Module or None expected)"
                ).format(_torch.typename(value), name)
            )
        modules[name] = value
        return

    buffers = self.__dict__.get('_buffers')
    if buffers is not None and name in buffers:
        if value is not None and not isinstance(value, _torch.Tensor):
            raise TypeError(
                "cannot assign '{}' as buffer '{}' "
                "(torch.Tensor or None expected)".format(
                    _torch.typename(value), name
                )
            )
        buffers[name] = value
        return

    # Not a parameter, module or buffer: ordinary attribute assignment.
    object.__setattr__(self, name, value)