in lib/util.py [0:0]
def forward(self, x, first, state):
    """Apply the block: pre-norm, recurrence with optional skip, then optional residual MLP.

    Args:
        x: Block input. It is also kept as the skip connection around the
            recurrence.
        first: Forwarded unchanged to ``recurrent_forward``.
        state: Recurrent state forwarded to ``recurrent_forward``.

    Returns:
        Tuple ``(output, state_out)``, where ``state_out`` is the state
        returned by ``recurrent_forward``.
    """
    skip = x
    normed = self.pre_r_ln(x)

    # For a "multi_layer_bilstm" stack, blocks with an odd block_number
    # (i.e. every second block) run the recurrence reversed.
    run_reversed = (
        self.recurrence_type == "multi_layer_bilstm"
        and (self.block_number + 1) % 2 == 0
    )
    out, state_out = recurrent_forward(
        self.r, normed, first, state, reverse_lstm=run_reversed
    )

    # Only LSTM-type recurrences get the skip connection here; the original
    # note says the transformer variant is already residual.
    if self.is_residual and "lstm" in self.recurrence_type:
        out = out + skip

    if not self.use_pointwise_layer:
        return out, state_out

    # Pointwise MLP, wrapped in its own residual connection when enabled.
    mlp_out = self.mlp1(self.mlp0(out))
    if self.is_residual:
        mlp_out = mlp_out + out
    return mlp_out, state_out