in src/kernels/layer.py [0:0]
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
    """Patch ``module`` in place so its ``forward`` follows ``layer`` per training mode.

    ``module.train`` is replaced by an instance-level wrapper. The wrapper first
    runs the module class's own ``train`` and then rebinds ``module.forward``:

    * eval mode (``mode=False``): ``layer.forward`` is used.
    * training mode (``mode=True``): ``layer.forward`` is used only if
      ``layer.has_backward`` is truthy (a missing attribute counts as ``True``).
      Otherwise the module class's original ``forward`` is kept.

    The selected ``forward`` is bound with ``module`` as ``self``, so it runs
    against the module's own attributes. The wrapper is invoked once at the end
    so ``forward`` matches the module's current ``training`` flag.
    PyTorch documents ``nn.Module.eval()`` as equivalent to ``train(False)``.

    Args:
        module: Module instance to patch.
        layer: Class whose ``forward`` replaces ``module.forward``.
    """
    import torch.nn as nn

    module_class = type(module)
    # Layers without a backward implementation must not be used while training;
    # fall back to the module's original forward in that case.
    layer_with_backward = (
        layer if getattr(layer, "has_backward", True) else module_class
    )

    def train(self, mode: bool = True) -> nn.Module:
        # Call the class-level ``train`` directly. ``super(type(self), self)``
        # would skip a ``train`` override defined on the module's own class.
        # Looking the method up on the class (not the instance) also avoids
        # recursing into this instance-level wrapper.
        module_class.train(self, mode)
        if mode:
            self.forward = MethodType(layer_with_backward.forward, self)
        else:
            self.forward = MethodType(layer.forward, self)
        return self

    # The instance attribute shadows the class's ``train`` method.
    module.train = MethodType(train, module)  # type: ignore[method-assign]

    # Trigger setting correct forward for the current state.
    module.train(module.training)