in kats/models/globalmodel/model.py [0:0]
def build_rnn(self) -> None:
    """Build the ``DilatedRNNStack`` described by ``self.params``.

    The input and output widths of the network are computed from the
    configured windows, optional features, quantiles and seasonality, and
    the resulting network is stored on ``self.rnn``.
    """
    cfg = self.params

    # Extra feature columns are only counted when a feature object is set.
    if cfg.gmfeature:
        n_features = cfg.gmfeature.get_feature_size(cfg.input_window)
    else:
        n_features = 0

    # Two additional positions for step_num_encode and step_size_encode.
    in_width = cfg.input_window + n_features + 2

    n_quantiles = len(cfg.quantile) if cfg.quantile is not None else 0

    # One additional position for the level smoothing parameter.
    out_width = cfg.fcst_window * n_quantiles + 1

    if cfg.seasonality > 1:
        in_width += 2 * cfg.seasonality
        # One additional position for the seasonality smoothing parameter.
        out_width += 1

    # Ensure plain ``int`` sizes for jit.
    in_width = int(in_width)
    out_width = int(out_width)

    self.rnn = DilatedRNNStack(
        cfg.nn_structure,
        cfg.cell_name,
        in_width,
        cfg.state_size,
        out_width,
        cfg.h_size,
        cfg.jit,
    )