in apex/apex/RNN/RNNBackend.py [0:0]
def forward(self, input, collect_hidden=False, reverse=False):
    """
    Run every cell in ``self.rnns`` over ``input`` one time step at a time.

    At each step the first layer receives ``input[seq]`` and every further
    layer receives element 0 of the previous layer's result.

    Args:
        input: sequence-major input; ``input.size(0)`` is used as the
            sequence length and ``input[seq]`` is fed to the first layer.
            Assumed to be laid out as (seq_len, bsz, features) -- TODO confirm
            against callers.
        collect_hidden: when true, keep the hidden states produced at every
            time step; otherwise keep only those produced at index
            ``seq_len - 1``.
        reverse: when true, walk the sequence from the last index to the
            first, then restore index order in the results.

    Returns:
        A tuple ``(output, hiddens)``.

        ``output`` is ``flatten_list`` applied to the per-step outputs of
        the last layer, in sequence-index order.

        ``hiddens`` is a list with one entry per hidden state
        (``self.rnns[0].n_hidden_states`` entries). With
        ``collect_hidden=False`` each entry is ``flatten_list`` of that
        state's per-layer values; with ``collect_hidden=True`` each entry is
        a list over time steps of such flattened values.
    """
    seq_len = input.size(0)
    step_order = reversed(range(seq_len)) if reverse else range(seq_len)

    # hidden_states[layer][step] holds whatever the cell returned at that step.
    hidden_states = [[] for _ in range(self.nLayers)]
    outputs = []

    for seq in step_order:
        prev_out = input[seq]
        for layer in range(self.nLayers):
            outs = self.rnns[layer](prev_out)

            # NOTE(review): with reverse=True and collect_hidden=False this
            # keeps the state from index seq_len - 1, which is the FIRST step
            # processed in reverse order -- confirm that this is intended
            # rather than the state after the full reverse pass.
            if collect_hidden or seq == seq_len - 1:
                hidden_states[layer].append(outs)

            # Element 0 of the cell result is what the next layer consumes.
            prev_out = outs[0]

        outputs.append(prev_out)

    if reverse:
        # Steps were produced last-to-first; put them back in index order.
        outputs.reverse()

    # outputs: list([seq_length] x Tensor([bsz][features]))
    #      ->  Tensor([seq_length][bsz][features])
    output = flatten_list(outputs)

    # Only a single step was recorded per layer when not collecting.
    kept_steps = seq_len if collect_hidden else 1
    n_hid = self.rnns[0].n_hidden_states

    # Transpose [layer][step][state] -> [state][step][layer].
    per_state = [
        [
            [hidden_states[k][j][i] for k in range(self.nLayers)]
            for j in range(kept_steps)
        ]
        for i in range(n_hid)
    ]

    if reverse:
        # Restore sequence-index order along the step dimension.
        per_state = [steps[::-1] for steps in per_state]

    # Flatten the layer dimension: [state][step] x Tensor([layer][bsz][features]).
    # This flattened structure is the one that must be returned; returning the
    # un-flattened nested lists would contradict the documented format.
    hiddens = [
        [flatten_list(layers) for layers in steps]
        for steps in per_state
    ]

    if not collect_hidden:
        # Drop the length-1 step dimension.
        hiddens = [steps[0] for steps in hiddens]

    return output, hiddens