# _replace_forward() — excerpt from src/kernels/layer.py


def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
    import torch.nn as nn

    module_class = type(module)
    layer_with_backward = (
        layer if getattr(layer, "has_backward", True) else module_class
    )

    def train(self, mode: bool = True) -> nn.Module:
        super(type(self), self).train(mode)
        if mode:
            self.forward = MethodType(layer_with_backward.forward, self)
        else:
            self.forward = MethodType(layer.forward, self)
        return self

    module.train = MethodType(train, module)  # type: ignore[method-assign]

    # Trigger setting correct forward for the current state.
    module.train(module.training)