in higher/patch.py [0:0]
def patched_forward(self, *args, params=None, **kwargs):
    if self.direct_submodule_call:
        # If the submodule was called directly, run the initialisation that
        # happens at the top-level call. If the *full set of params* is
        # provided here, those will be used. If not, it will fall back on
        # the fast weights. In the future, we should be able to support
        # passing only the submodule (+ children) weights here, but that's
        # not simple.
        self.root._refill_params_box(params)
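
    # Note: params_box, params_offset, num_params, module and true_forward
    # are not defined in this function; they are free variables captured
    # from the enclosing scope in which patched_forward is built (higher's
    # _make_functional closure). params_box[0] holds the flat list of fast
    # weights, and params_offset is this submodule's offset into that list.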
    with _modify_internally(self):
        for name, param in zip(
            self._param_names,
            params_box[0][params_offset:params_offset + num_params]
        ):
            setattr(self, name, param)

        # This snippet deals with torch.nn.{RNN,GRU,LSTM}
        if hasattr(self, "_flat_weights_names"):
            self._flat_weights = [
                self._parameters[wn] for wn in self._flat_weights_names
            ]
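        # torch.nn.RNNBase reads self._flat_weights in its forward pass, so
        # the rebuilt list above must point at the fast weights just set.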

    # Call true_forward after some checks.
    with _warnings.catch_warnings():
        # If running RNNs on GPU, suppress the warnings due to flattening
        # not happening here. Maybe we should raise a warning of our own?
        is_RNN = isinstance(module, _torch.nn.RNNBase)
        if is_RNN and _torch.cuda.is_available():
            _warnings.simplefilter("ignore", category=UserWarning)

        return true_forward(self, *args, **kwargs)
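
For context, a minimal usage sketch (not part of patch.py) of how this patched forward is typically exercised through higher's monkeypatch helper; the names model, fmodel, x and new_params are illustrative:

import torch
import higher

model = torch.nn.Linear(4, 2)
# monkeypatch returns a functional copy of the module whose forward has been
# replaced by a patched forward like the one above; its fast weights start
# out as copies of model's parameters.
fmodel = higher.monkeypatch(model)

x = torch.randn(8, 4)

# With no explicit params, the call falls back on the stored fast weights.
y0 = fmodel(x)

# Passing the full flat set of params refills the params box, so the setattr
# loop above swaps these tensors in before the original forward runs.
new_params = [p.clone().detach().requires_grad_() for p in fmodel.parameters()]
y1 = fmodel(x, params=new_params)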