in torchmoji/lstm.py [0:0]
def VariableRecurrentReverse(batch_sizes, inner):
    """Build a closure that drives ``inner`` over packed rows from last step to first.

    Args:
        batch_sizes: number of rows of ``input`` consumed at each time step,
            listed in forward order. The closure walks it in reverse, slicing
            rows off the end of ``input``. Presumably non-increasing, as for a
            packed sequence — TODO confirm against callers.
        inner: step function called as ``inner(step_input, state, *weight)``.
            ``state`` is a bare tensor when the caller's hidden state is not a
            tuple; otherwise it is the tuple itself. The function must return
            the same kind of value.

    Returns:
        ``forward(input, hidden, weight) -> (hidden, output)``. ``output`` is the
        first element of the state produced at every step, concatenated along
        dim 0 in forward step order.
    """

    def forward(input, hidden, weight):
        # Box a lone tensor into a 1-tuple so both layouts share one code path.
        is_single = not isinstance(hidden, tuple)
        start_state = (hidden,) if is_single else hidden

        end = input.size(0)
        # Start from the rows that are active at the final step.
        active = batch_sizes[-1]
        state = tuple(part[:active] for part in start_state)

        step_outputs = []
        for rows in reversed(batch_sizes):
            if rows > active:
                # More rows are active at this step: seed the extra rows from
                # the matching rows of the caller's initial state.
                state = tuple(
                    torch.cat((cur, init[active:rows]), 0)
                    for cur, init in zip(state, start_state)
                )
            active = rows

            # This step's rows sit directly before the ones already consumed.
            chunk = input[end - rows:end]
            end -= rows

            if is_single:
                state = (inner(chunk, state[0], *weight),)
            else:
                state = inner(chunk, state, *weight)
            step_outputs.append(state[0])

        # Outputs were gathered last-to-first; flip back to forward order.
        output = torch.cat(step_outputs[::-1], 0)
        final_hidden = state[0] if is_single else state
        return final_hidden, output

    return forward