def patched_forward()

in higher/patch.py [0:0]


    def patched_forward(self, *args, params=None, **kwargs):
        """Bind this module's slice of the params box, then run true_forward.

        Args:
            *args: Positional arguments passed through to ``true_forward``.
            params: Only consulted when ``self.direct_submodule_call`` is
                truthy, in which case it is handed to
                ``self.root._refill_params_box``. Otherwise it is unused here.
            **kwargs: Keyword arguments passed through to ``true_forward``.

        Returns:
            Whatever ``true_forward(self, *args, **kwargs)`` returns.
        """
        if self.direct_submodule_call:
            # A submodule invoked directly has to perform the set-up that the
            # top-level call would normally do. Per the original author's
            # note: when the *full set of params* is supplied it is used,
            # otherwise the fast weights are the fallback. Accepting only the
            # submodule's (+ children's) weights is not supported yet, as that
            # is not simple to do.
            self.root._refill_params_box(params)

        with _modify_internally(self):
            # This module owns `num_params` consecutive entries of the box,
            # starting at `params_offset`.
            slice_end = params_offset + num_params
            own_params = params_box[0][params_offset:slice_end]
            for attr_name, tensor in zip(self._param_names, own_params):
                setattr(self, attr_name, tensor)

            # torch.nn.{RNN,GRU,LSTM} keep a separate flat list of weights,
            # which has to be rebuilt from the freshly assigned parameters.
            if hasattr(self, "_flat_weights_names"):
                flat_names = self._flat_weights_names
                self._flat_weights = [
                    self._parameters[key] for key in flat_names
                ]

        # Hand over to true_forward once the warning filter is configured.
        with _warnings.catch_warnings():
            # RNNs on GPU warn because weight flattening does not happen
            # here; suppress those warnings. Maybe we should raise a warning
            # of our own instead?
            rnn_on_gpu = (
                isinstance(module, _torch.nn.RNNBase)
                and _torch.cuda.is_available()
            )
            if rnn_on_gpu:
                _warnings.simplefilter("ignore", category=UserWarning)

            return true_forward(self, *args, **kwargs)