def forward()

in lib/util.py [0:0]


    def forward(self, x, first, state):
        """Apply one residual recurrent block: pre-norm, recurrence, then an optional pointwise MLP.

        Args:
            x: Block input. It is normalised by ``self.pre_r_ln`` and also kept as the
                skip-connection source. (Shape/dtype not established here; presumably a
                tensor accepted by ``self.pre_r_ln`` and ``recurrent_forward``. TODO confirm.)
            first: Passed through unchanged to ``recurrent_forward``.
            state: Recurrent state passed to ``recurrent_forward``.

        Returns:
            A tuple ``(output, new_state)``. ``new_state`` is the second value returned
            by ``recurrent_forward``.
        """
        skip = x
        normed = self.pre_r_ln(x)

        # Only the "multi_layer_bilstm" configuration can reverse the LSTM. It does so on
        # blocks where (block_number + 1) is even. The `and` short-circuits, so block_number
        # is only read for that configuration.
        run_reversed = self.recurrence_type == "multi_layer_bilstm" and (self.block_number + 1) % 2 == 0
        out, new_state = recurrent_forward(self.r, normed, first, state, reverse_lstm=run_reversed)

        # The explicit skip is added only for LSTM-type recurrences. The original note says
        # the transformer path is already residual.
        if self.is_residual and "lstm" in self.recurrence_type:
            out = out + skip

        if not self.use_pointwise_layer:
            return out, new_state

        # Pointwise MLP (mlp0 followed by mlp1) with its own optional skip connection.
        mlp_out = self.mlp1(self.mlp0(out))
        if self.is_residual:
            mlp_out = mlp_out + out
        return mlp_out, new_state