in src/gluonts/model/seq2seq/_forking_estimator.py [0:0]
def _create_instance_splitter(self, mode: str):
    """
    Build the transformation that cuts model instances out of each series.

    The returned chain always starts with a ``ForkingSequenceSplitter``.
    When ``self.use_past_feat_dynamic_real`` is set, a ``VstackFeatures``
    step is appended. It merges the splitter's encoder-side outputs for the
    past-only dynamic features with those derived from
    ``FieldName.FEAT_DYNAMIC`` into ``FieldName.PAST_FEAT_DYNAMIC``.

    Parameters
    ----------
    mode
        One of ``"training"``, ``"validation"`` or ``"test"``. It selects
        the instance sampler: ``self.train_sampler``,
        ``self.validation_sampler`` or a new ``TestSplitSampler``,
        respectively. ``FieldName.OBSERVED_VALUES`` is added to the decoder
        series fields for every mode except ``"test"``.

    Returns
    -------
    Chain
        The splitter, optionally followed by the feature-stacking step.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the accepted values.
    """
    valid_modes = ("training", "validation", "test")
    if mode not in valid_modes:
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let a bad mode slip through.
        raise ValueError(f"mode must be one of {valid_modes}, got {mode!r}")

    # Training and validation use the samplers configured on the estimator;
    # test mode uses a TestSplitSampler, built only when it is needed.
    if mode == "training":
        instance_sampler = self.train_sampler
    elif mode == "validation":
        instance_sampler = self.validation_sampler
    else:
        instance_sampler = TestSplitSampler()

    encoder_series_fields = [
        FieldName.OBSERVED_VALUES,
        # RTS with past and future values; never empty because a dummy
        # constant variable is added to it.
        FieldName.FEAT_DYNAMIC,
    ]
    if self.use_past_feat_dynamic_real:
        # RTS with only past values are used by the encoder only.
        encoder_series_fields.append(FieldName.PAST_FEAT_DYNAMIC_REAL)

    # Disabling encoder dynamic features covers both the past-and-future RTS
    # and, when present, the past-only RTS.
    encoder_disabled_fields = []
    if not self.enable_encoder_dynamic_feature:
        encoder_disabled_fields.append(FieldName.FEAT_DYNAMIC)
        if self.use_past_feat_dynamic_real:
            encoder_disabled_fields.append(FieldName.PAST_FEAT_DYNAMIC_REAL)

    # The decoder uses every RTS under FEAT_DYNAMIC (those with past and
    # future values); OBSERVED_VALUES is added only outside test mode.
    decoder_series_fields = [FieldName.FEAT_DYNAMIC]
    if mode != "test":
        decoder_series_fields.append(FieldName.OBSERVED_VALUES)

    decoder_disabled_fields = (
        [] if self.enable_decoder_dynamic_feature else [FieldName.FEAT_DYNAMIC]
    )

    chain = [
        ForkingSequenceSplitter(
            instance_sampler=instance_sampler,
            enc_len=self.context_length,
            dec_len=self.prediction_length,
            num_forking=self.num_forking,
            encoder_series_fields=encoder_series_fields,
            encoder_disabled_fields=encoder_disabled_fields,
            decoder_series_fields=decoder_series_fields,
            decoder_disabled_fields=decoder_disabled_fields,
            prediction_time_decoder_exclude=[FieldName.OBSERVED_VALUES],
        )
    ]

    if self.use_past_feat_dynamic_real:
        # The splitter prefixes its encoder outputs with "past_".
        # - FEAT_DYNAMIC therefore comes out as FieldName.PAST_FEAT_DYNAMIC.
        # - PAST_FEAT_DYNAMIC_REAL comes out as
        #   "past_" + FieldName.PAST_FEAT_DYNAMIC_REAL.
        # These outputs were transposed by the splitter, so they are stacked
        # horizontally (h_stack=True). The result, shaped
        # (enc_len, num_past_feat_dynamic), is written back to
        # FieldName.PAST_FEAT_DYNAMIC.
        chain.append(
            VstackFeatures(
                output_field=FieldName.PAST_FEAT_DYNAMIC,
                input_fields=[
                    "past_" + FieldName.PAST_FEAT_DYNAMIC_REAL,
                    FieldName.PAST_FEAT_DYNAMIC,
                ],
                h_stack=True,
            )
        )

    return Chain(chain)