in torchmoji/lstm.py [0:0]
def VariableRecurrent(batch_sizes, inner):
    """Build a function that applies ``inner`` step by step over chunked input.

    Args:
        batch_sizes: Sequence giving, for each step, how many leading rows
            of the input (and of the hidden state) take part in that step.
            NOTE(review): presumably non-increasing, as in a packed
            sequence -- growth between steps is not handled; confirm
            against callers.
        inner: Callable invoked as ``inner(step_input, hidden, *weight)``
            that returns the next hidden state (a single tensor, or a tuple
            of tensors when the initial hidden state is a tuple).

    Returns:
        A ``forward(input, hidden, weight)`` closure returning
        ``(final_hidden, output)``.
    """

    def forward(input, hidden, weight):
        """Run ``inner`` once per entry of ``batch_sizes``.

        ``input`` is consumed along dim 0 in consecutive chunks whose sizes
        are given by ``batch_sizes``. ``hidden`` may be a single tensor or
        a tuple of tensors; the returned hidden state has the same form.
        ``weight`` is unpacked into the trailing arguments of ``inner``.
        """
        # Work with a tuple internally so both hidden-state forms share
        # one code path; unwrap again just before returning.
        single_state = not isinstance(hidden, tuple)
        state = (hidden,) if single_state else hidden

        step_outputs = []
        finished = []  # hidden rows trimmed off as the batch shrinks
        cursor = 0
        previous_size = batch_sizes[0]

        for current_size in batch_sizes:
            chunk = input[cursor:cursor + current_size]
            cursor += current_size

            shrink = previous_size - current_size
            if shrink > 0:
                # The trailing rows no longer take part: set their state
                # aside and keep stepping with the leading rows only.
                finished.append(tuple(s[-shrink:] for s in state))
                state = tuple(s[:-shrink] for s in state)
            previous_size = current_size

            if single_state:
                state = (inner(chunk, state[0], *weight),)
            else:
                state = inner(chunk, state, *weight)

            # The first element of the state is recorded as the step output.
            step_outputs.append(state[0])

        # Rows still active after the last step come first; the trimmed
        # groups follow in reverse order of removal, which restores the
        # original row order along dim 0.
        finished.append(state)
        merged = tuple(
            torch.cat(parts, 0) for parts in zip(*reversed(finished))
        )
        assert merged[0].size(0) == batch_sizes[0]

        final_hidden = merged[0] if single_state else merged
        return final_hidden, torch.cat(step_outputs, 0)

    return forward