in src/gluonts/model/san/_estimator.py [0:0]
def create_transformation(self) -> Transformation:
    """
    Build the per-entry data transformation chain for this estimator.

    Each optional covariate field (dynamic real, dynamic categorical,
    static real, static categorical) is converted to a numpy array of the
    expected rank. When a dynamic/static real or dynamic categorical
    covariate is disabled, an empty placeholder is written first so the
    field always exists. After that, the target is converted, an
    observed-values indicator is added, and time and age features are
    computed. Finally these features (plus user dynamic real features,
    when enabled) are stacked into ``FieldName.FEAT_DYNAMIC_REAL``.

    Returns
    -------
    Transformation
        A ``Chain`` of the transformations, in the order they are applied.
    """
    steps = []

    # Dynamic real covariates (rank 2).
    if not self.use_feat_dynamic_real:
        # Placeholder: one empty row per time step over context + horizon,
        # i.e. an array of shape (time, 0).
        # NOTE(review): the VstackFeatures step at the end writes
        # FEAT_DYNAMIC_REAL as its output and, in this branch, does not read
        # it as an input, so this placeholder appears to be overwritten.
        # Its (time, 0) orientation may also be transposed relative to
        # user-supplied features. Confirm whether it is needed.
        steps.append(
            SetField(
                output_field=FieldName.FEAT_DYNAMIC_REAL,
                value=[[]] * (self.context_length + self.prediction_length),
            )
        )
    steps.append(
        AsNumpyArray(field=FieldName.FEAT_DYNAMIC_REAL, expected_ndim=2)
    )

    # Dynamic categorical covariates (rank 2).
    if self.use_feat_dynamic_cat:
        steps.append(
            AsNumpyArray(field=FieldName.FEAT_DYNAMIC_CAT, expected_ndim=2)
        )
    else:
        # Write empty placeholders that are already split into past/future
        # windows. The original authors reported an unresolved data-loader
        # issue when this split was left to the InstanceSplitter.
        for prefix, window in (
            ("past_", self.context_length),
            ("future_", self.prediction_length),
        ):
            split_field = prefix + FieldName.FEAT_DYNAMIC_CAT
            steps.append(
                SetField(output_field=split_field, value=[[]] * window)
            )
            steps.append(AsNumpyArray(field=split_field, expected_ndim=2))

    # Static real covariates (rank 1).
    if not self.use_feat_static_real:
        steps.append(
            SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[])
        )
    steps.append(
        AsNumpyArray(field=FieldName.FEAT_STATIC_REAL, expected_ndim=1)
    )

    # Static categorical covariates (rank 1).
    # NOTE(review): unlike the other covariates, no placeholder is written
    # when static categorical features are disabled; confirm that
    # downstream code does not require FEAT_STATIC_CAT in that case.
    if self.use_feat_static_cat:
        steps.append(
            AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1)
        )

    # Time and age features are always stacked; user-provided dynamic real
    # features are appended only when enabled.
    stacked_inputs = [FieldName.FEAT_TIME, FieldName.FEAT_AGE]
    if self.use_feat_dynamic_real:
        stacked_inputs.append(FieldName.FEAT_DYNAMIC_REAL)

    steps += [
        AsNumpyArray(field=FieldName.TARGET, expected_ndim=1),
        AddObservedValuesIndicator(
            target_field=FieldName.TARGET,
            output_field=FieldName.OBSERVED_VALUES,
        ),
        AddTimeFeatures(
            start_field=FieldName.START,
            target_field=FieldName.TARGET,
            output_field=FieldName.FEAT_TIME,
            time_features=self.time_features,
            pred_length=self.prediction_length,
        ),
        AddAgeFeature(
            target_field=FieldName.TARGET,
            output_field=FieldName.FEAT_AGE,
            pred_length=self.prediction_length,
            log_scale=True,
        ),
        VstackFeatures(
            output_field=FieldName.FEAT_DYNAMIC_REAL,
            input_fields=stacked_inputs,
        ),
    ]
    return Chain(steps)