in src/gluonts/model/seq2seq/_forking_estimator.py [0:0]
def create_transformation(self) -> Transformation:
    """
    Build the input transformation chain for the forking seq2seq model.

    The chain

    1. removes input fields the model does not use: ``FEAT_DYNAMIC_CAT`` and
       ``FEAT_STATIC_REAL`` always, and ``PAST_FEAT_DYNAMIC_REAL``,
       ``FEAT_DYNAMIC_REAL`` and ``FEAT_STATIC_CAT`` when the corresponding
       ``use_*`` flag is off;
    2. adds the observed-values indicator for the target;
    3. collects the dynamic features that cover context and prediction range
       (time features, age feature, ``FEAT_DYNAMIC_REAL``, plus a constant
       dummy feature when none of the enabled ones carries any values) and
       stacks or renames them into ``FieldName.FEAT_DYNAMIC``;
    4. sets ``FEAT_STATIC_CAT`` to the placeholder ``np.array([0])`` (int32)
       when static categorical features are not used.

    Returns
    -------
    Transformation
        A ``Chain`` of the transformations listed above.
    """
    chain = []
    dynamic_feat_fields = []
    remove_field_names = [
        FieldName.FEAT_DYNAMIC_CAT,
        FieldName.FEAT_STATIC_REAL,
    ]

    # --- GENERAL TRANSFORMATION CHAIN ---

    # determine unused input
    if not self.use_past_feat_dynamic_real:
        remove_field_names.append(FieldName.PAST_FEAT_DYNAMIC_REAL)
    if not self.use_feat_dynamic_real:
        remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
    if not self.use_feat_static_cat:
        remove_field_names.append(FieldName.FEAT_STATIC_CAT)

    chain.extend(
        [
            RemoveFields(field_names=remove_field_names),
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
                dtype=self.dtype,
            ),
        ]
    )

    # --- TRANSFORMATION CHAIN FOR DYNAMIC FEATURES ---

    # Whether the time-feature field will actually carry any features. The
    # time-feature list is empty for yearly frequencies, so we inspect the
    # computed list instead of matching a particular frequency string.
    has_time_features = False
    if self.add_time_feature:
        time_features = time_features_from_frequency_str(self.freq)
        has_time_features = len(time_features) > 0
        chain.append(
            AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=time_features,
                pred_length=self.prediction_length,
                dtype=self.dtype,
            )
        )
        dynamic_feat_fields.append(FieldName.FEAT_TIME)

    if self.add_age_feature:
        chain.append(
            AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=self.prediction_length,
                dtype=self.dtype,
            )
        )
        dynamic_feat_fields.append(FieldName.FEAT_AGE)

    if self.use_feat_dynamic_real:
        # Backwards compatibility: older datasets use the key "dynamic_feat".
        chain.append(
            RenameFields({"dynamic_feat": FieldName.FEAT_DYNAMIC_REAL})
        )
        dynamic_feat_fields.append(FieldName.FEAT_DYNAMIC_REAL)

    # We need to make sure that there is always some dynamic input; it is
    # however disregarded in the hybrid forward. A constant dummy feature is
    # added whenever none of the enabled dynamic features carries values:
    # either no dynamic feature is enabled at all, or the time feature is the
    # only one on but the frequency yields no time features (e.g. yearly data).
    # Previously this compared self.freq to the literal "Y", which missed
    # equivalent yearly frequency strings such as "A" or "A-DEC".
    if not (
        has_time_features
        or self.add_age_feature
        or self.use_feat_dynamic_real
    ):
        chain.append(
            AddConstFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_CONST,
                pred_length=self.prediction_length,
                const=0.0,  # For consistency in case with no dynamic features
                dtype=self.dtype,
            )
        )
        dynamic_feat_fields.append(FieldName.FEAT_CONST)

    # Now map all the dynamic input of length context_length +
    # prediction_length onto FieldName.FEAT_DYNAMIC. past_feat_dynamic_real
    # is excluded since its length is only context_length.
    if len(dynamic_feat_fields) > 1:
        chain.append(
            VstackFeatures(
                output_field=FieldName.FEAT_DYNAMIC,
                input_fields=dynamic_feat_fields,
            )
        )
    elif len(dynamic_feat_fields) == 1:
        chain.append(
            RenameFields({dynamic_feat_fields[0]: FieldName.FEAT_DYNAMIC})
        )

    # --- TRANSFORMATION CHAIN FOR STATIC FEATURES ---

    if not self.use_feat_static_cat:
        # Placeholder category so downstream code always finds the field.
        chain.append(
            SetField(
                output_field=FieldName.FEAT_STATIC_CAT,
                value=np.array([0], dtype=np.int32),
            )
        )

    return Chain(chain)