def recurrent_forward()

in lib/util.py [0:0]


def recurrent_forward(module, x, first, state, reverse_lstm=False):
    """Run a recurrent ``module`` over a sequence chunk with a uniform call signature.

    ``nn.LSTM`` has no ``first`` (episode-start) argument. For LSTMs this wrapper
    therefore applies the reset itself. It also converts the hidden state between
    the caller's batch-major layout and the layer-major layout ``nn.LSTM`` expects.
    Any other module is called as ``module(x, first, state)`` and its result is
    returned unchanged.

    Args:
        module: An ``nn.LSTM``, or a recurrent module called as ``(x, first, state)``.
        x: Input sequence. For LSTMs the time axis is taken to be dim 1, since that
            is the axis flipped when ``reverse_lstm`` is set. This presumes the LSTM
            was built with ``batch_first=True``; TODO confirm at construction site.
        first: Episode-start flags, read as ``first[:, 0]``. For LSTMs only the
            first timestep's flag is used; resets later in the chunk are NOT applied.
        state: For LSTMs, ``None`` or a pytree (e.g. ``(h, c)``) of tensors in
            ``(B, NL, H)`` layout. Leaves are zeroed for rows whose ``first[:, 0]``
            is set. ``None`` lets ``nn.LSTM`` use its default initial state.
        reverse_lstm: LSTM only. If True, run over the time-reversed input and flip
            the output back into the original time order.

    Returns:
        For LSTMs, ``(output, state_out)`` with ``state_out`` in ``(B, NL, H)``
        layout. Otherwise, whatever ``module(x, first, state)`` returns.
    """
    if not isinstance(module, nn.LSTM):
        return module(x, first, state)

    if state is not None:
        # nn.LSTM does not accept a "first" argument, so zero the carried state
        # here for rows starting a new episode. (B, 1, 1) broadcasts over (B, NL, H).
        keep = 1 - first[:, 0, None, None].to(th.float)

        def _prepare(s):
            # Cast the mask to the leaf's dtype/device so half/bf16 state is not
            # silently promoted to float32. Then go (B, NL, H) -> (NL, B, H).
            # The cuDNN RNN path rejects non-contiguous hidden states.
            return (s * keep.to(s)).transpose(0, 1).contiguous()

        state = tree_map(_prepare, state)

    if reverse_lstm:
        x = th.flip(x, [1])
    x, state_out = module(x, state)
    if reverse_lstm:
        x = th.flip(x, [1])
    # Back to the caller's (B, NL, H) layout.
    state_out = tree_map(lambda s: s.transpose(0, 1), state_out)
    return x, state_out