in pytorch_translate/common_layers.py [0:0]
def forward(self, x, hidden, batch_size_per_step):
    """Unroll ``self.rnn_cell`` over a packed sequence with per-step batch sizes.

    Args:
        x: packed input of shape ``(sum(batch_size_per_step), input_dim)`` —
            time steps concatenated along dim 0, each step holding only the
            still-active sequences.
        hidden: initial hidden state — a single tensor or a tuple of tensors
            (e.g. ``(h, c)`` for LSTM cells), each of shape
            ``(max_batch, hidden_dim)``.
        batch_size_per_step: number of active sequences at each time step
            (non-increasing in the forward direction).

    Returns:
        ``(hidden, output)`` — the final hidden state(s) covering the full
        batch, and the packed per-step outputs concatenated along dim 0.
    """
    self.batch_size_per_step = batch_size_per_step
    self.starting_batch_size = (
        batch_size_per_step[-1] if self.reverse else batch_size_per_step[0]
    )
    # Normalize a plain-tensor hidden state into a 1-tuple so the loop body
    # can treat both cell flavors uniformly; unwrap again before returning.
    was_flat = not isinstance(hidden, tuple)
    if was_flat:
        hidden = (hidden,)
    init_hidden = hidden
    if self.reverse:
        # Going right-to-left, the first step processed is the shortest one,
        # so start from the matching leading slice of the hidden state.
        hidden = tuple(h[: self.batch_size_per_step[-1]] for h in hidden)

    outputs = []
    pad_hiddens = []
    offset = x.size(0) if self.reverse else 0
    prev_bsz = self.starting_batch_size
    num_steps = len(self.batch_size_per_step)

    # Walk the time steps, slicing out each step's chunk of the packed input.
    for step in range(num_steps):
        if self.reverse:
            bsz = self.batch_size_per_step[num_steps - 1 - step]
            step_input = x[(offset - bsz) : offset]
            offset -= bsz
        else:
            bsz = self.batch_size_per_step[step]
            step_input = x[offset : (offset + bsz)]
            offset += bsz

        n_pads = prev_bsz - bsz
        if n_pads > 0:
            # Batch shrank: the trailing rows belong to sequences that just
            # ended — stash their final hidden states, drop them from `hidden`.
            pad_hiddens.insert(0, tuple(h[-n_pads:] for h in hidden))
            hidden = tuple(h[:-n_pads] for h in hidden)
        elif n_pads < 0:
            # Batch grew (only happens in the reverse pass): re-attach the
            # initial hidden rows for the sequences entering at this step.
            hidden = tuple(
                torch.cat((h, ih[prev_bsz:bsz]), 0)
                for h, ih in zip(hidden, init_hidden)
            )
        prev_bsz = bsz

        if was_flat:
            hidden = (self.rnn_cell(step_input, hidden[0]),)
        else:
            hidden = self.rnn_cell(step_input, hidden)
        outputs.append(hidden[0])

    if not self.reverse:
        # Reassemble the full-batch final hidden state: the running `hidden`
        # plus all the slices stashed when sequences finished.
        pad_hiddens.insert(0, hidden)
        hidden = tuple(torch.cat(h, 0) for h in zip(*pad_hiddens))
    assert outputs[0].size(0) == self.starting_batch_size

    if was_flat:
        hidden = hidden[0]
    if self.reverse:
        outputs.reverse()
    return hidden, torch.cat(outputs, 0)