in lib/util.py [0:0]
def recurrent_forward(module, x, first, state, reverse_lstm=False):
    """Run one chunk of a recurrent ``module`` and return ``(output, new_state)``.

    Modules other than ``nn.LSTM`` are presumed to handle episode resets
    themselves and are simply called as ``module(x, first, state)``.

    ``nn.LSTM`` takes no ``first`` argument, so it is adapted here:

    * the incoming state is zeroed for every batch row whose *first* step is
      flagged in ``first``. Only ``first[:, 0]`` is read, so a reset flagged
      later inside the chunk is not seen by the LSTM path;
    * the state is kept batch-major ``(B, NL, H)`` outside this function but
      fed to ``nn.LSTM`` layer-major ``(NL, B, H)``, so it is transposed on the
      way in and back on the way out;
    * with ``reverse_lstm`` the sequence is flipped along dim 1 before the LSTM
      and the output is flipped back into the original order afterwards.

    Args:
        module: an ``nn.LSTM`` or any recurrent module accepting
            ``(x, first, state)``. For the LSTM path, time is flipped along
            dim 1, which presumes ``batch_first=True`` -- TODO confirm.
        x: input sequence; for the LSTM path dim 1 is treated as time.
        first: episode-start flags; the LSTM path reads ``first[:, 0]``, so this
            is presumably ``(B, T)`` -- verify against callers.
        state: recurrent state, or ``None`` to use the module's default initial
            state. For the LSTM path, a pytree (e.g. ``(h, c)``) of
            ``(B, NL, H)`` tensors.
        reverse_lstm: run the LSTM over the time-reversed sequence.

    Returns:
        ``(output, new_state)``. For the LSTM path ``new_state`` is batch-major
        ``(B, NL, H)``.
    """
    if not isinstance(module, nn.LSTM):
        return module(x, first, state)

    if state is not None:
        # nn.LSTM does not accept a "first" argument, so zero out the hidden
        # state here for rows that start a new episode.
        mask = 1 - first[:, 0, None, None].to(th.float)  # (B, 1, 1)
        # Match each state tensor's dtype/device: a float32 mask would otherwise
        # promote a float16/bfloat16 state to float32 and mismatch the LSTM.
        state = tree_map(lambda _s: _s * mask.to(_s), state)
        # B, NL, H -> NL, B, H. transpose() yields a non-contiguous view, and the
        # cuDNN RNN kernel rejects non-contiguous hx/cx.
        state = tree_map(lambda _s: _s.transpose(0, 1).contiguous(), state)

    if reverse_lstm:
        x = th.flip(x, [1])
    x, state_out = module(x, state)
    if reverse_lstm:
        x = th.flip(x, [1])
    state_out = tree_map(lambda _s: _s.transpose(0, 1), state_out)  # NL, B, H -> B, NL, H
    return x, state_out