in kats/models/globalmodel/utils.py [0:0]
def forward(self, input_t: Tensor) -> Tensor:
    """Forward method of DilatedRNNStack.

    Runs one time step through every layer of every block in the stack.

    Each call opens a new slot in ``self.h_state_store`` and
    ``self.c_state_store`` and appends one hidden state and one cell state per
    layer to it. The current step index ``t`` is the number of slots that
    existed before this call.

    For a layer with dilation ``d = self.nn_structure[iblock][lay]``, the
    recurrent state comes from step ``t - d`` when that step exists.
    Otherwise it comes from step ``t - 1`` when any earlier step exists.
    Otherwise (plain LSTM branch) the cell is called without an explicit
    state. The S2Cell / LSTM2Cell helpers receive the same availability flags
    and step indices.

    The output of each block is added to the running sum of the previous
    block outputs (starting from zeros), forming a residual connection. The
    final sum is passed through ``self.adaptor`` when one is set.

    Args:
        input_t: A `torch.Tensor` object representing input features of shape (batch_size, input_size).

    Returns:
        A `torch.Tensor` object representing outputs of shape (batch_size, output_size).
    """
    # Residual accumulator. Allocate it on input_t's device so the residual
    # sum below does not mix devices (e.g. a CPU tensor with CUDA cell outputs).
    # The dtype is kept as torch.float, as before.
    prev_block_output = torch.zeros(
        input_t.shape[0], self.out_size, dtype=torch.float, device=input_t.device
    )
    # Index of the current time step; open a fresh state slot for it.
    t = len(self.h_state_store)
    self.h_state_store.append([])
    self.c_state_store.append([])
    # Placeholder so output_t is bound before the loop. It is only read when
    # lay > 0, i.e. after the previous layer of the same block has assigned it.
    output_t = NoneT  # NOTE(review): NoneT is a module-level sentinel defined elsewhere in the file
    has_prev_state = t > 0
    # Flat layer index across all blocks; indexes self.cells and the
    # per-step state lists.
    layer = 0
    for iblock in range(self.block_num):
        for lay in range(len(self.nn_structure[iblock])):
            if lay == 0:
                # The first layer of block 0 sees the raw input; the first
                # layer of every later block sees the accumulated residual output.
                tmp_input = input_t if iblock == 0 else prev_block_output
            else:
                tmp_input = output_t
            # Step whose stored state feeds this dilated layer.
            ti_1 = t - self.nn_structure[iblock][lay]
            has_delayed_state = ti_1 >= 0
            if self.cell_name == "S2Cell":
                output_t, (h_state, new_state) = self._forward_S2Cell(
                    tmp_input, layer, has_prev_state, has_delayed_state, t, ti_1
                )
            elif self.cell_name == "LSTM2Cell":
                output_t, (h_state, new_state) = self._forward_LSTM2Cell(
                    tmp_input, layer, has_prev_state, has_delayed_state, t, ti_1
                )
            else:  # LSTM
                if has_delayed_state or has_prev_state:
                    # Prefer the dilated state; fall back to the previous step.
                    src_t = ti_1 if has_delayed_state else t - 1
                    h_state, new_state = self.cells[layer](
                        tmp_input,
                        (
                            self.h_state_store[src_t][layer],
                            self.c_state_store[src_t][layer],
                        ),
                    )
                else:
                    h_state, new_state = self.cells[layer](tmp_input)
                output_t = h_state
            self.h_state_store[t].append(h_state)
            self.c_state_store[t].append(new_state)
            layer += 1
        # Residual connection: accumulate this block's output.
        prev_block_output = output_t + prev_block_output
    if self.adaptor is not None:
        output_t = self.adaptor(prev_block_output)
    else:
        output_t = prev_block_output
    return output_t