in src/gluonts/model/tft/_estimator.py [0:0]
def create_transformation(self) -> Transformation:
    """
    Assemble the per-entry data transformation pipeline.

    The pipeline first converts the target and every declared feature field
    to numpy arrays, then adds the observed-values indicator and the time
    features. After that it builds the feature groups ``FEAT_STATIC_CAT``,
    ``FEAT_STATIC_REAL``, ``FEAT_DYNAMIC_CAT``, ``FEAT_DYNAMIC_REAL``,
    ``PAST_FEAT_DYNAMIC + "_cat"`` and ``PAST_FEAT_DYNAMIC_REAL``.

    When a group's declared feature mapping is empty (falsy), a constant
    zero placeholder is written in its place, so every one of these output
    fields is present on each entry.

    Returns
    -------
    Transformation
        A ``Chain`` of the individual transformations, in application order.
    """

    def stack_or_default(
        output_field, features, default, ndim, broadcast=None, **stack_kwargs
    ):
        """
        Return the steps that fill ``output_field`` for one feature group.

        If ``features`` is truthy, return a single ``VstackFeatures`` over
        its keys, forwarding ``stack_kwargs``. Otherwise, return steps that
        set ``default``, convert it to an array of rank ``ndim``, and, when
        ``broadcast`` is not None, broadcast it with those ``BroadcastTo``
        keyword arguments.
        """
        if features:
            return [
                VstackFeatures(
                    output_field=output_field,
                    input_fields=list(features.keys()),
                    **stack_kwargs,
                )
            ]
        steps = [
            SetField(output_field=output_field, value=default),
            AsNumpyArray(field=output_field, expected_ndim=ndim),
        ]
        if broadcast is not None:
            steps.append(BroadcastTo(field=output_field, **broadcast))
        return steps

    # Array conversion of the raw inputs: the target, static fields and
    # dynamic categorical fields are required to be 1-D; dynamic real
    # fields are required to be 2-D.
    transforms = [AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)]
    transforms.extend(
        AsNumpyArray(field=name, expected_ndim=1)
        for name in chain(
            self.static_cardinalities.keys(),
            self.static_feature_dims.keys(),
            self.dynamic_cardinalities.keys(),
        )
    )
    transforms.extend(
        AsNumpyArray(field=name, expected_ndim=2)
        for name in self.dynamic_feature_dims.keys()
    )
    transforms.append(
        AddObservedValuesIndicator(
            target_field=FieldName.TARGET,
            output_field=FieldName.OBSERVED_VALUES,
        )
    )
    transforms.append(
        AddTimeFeatures(
            start_field=FieldName.START,
            target_field=FieldName.TARGET,
            output_field=FieldName.FEAT_TIME,
            time_features=self.time_features,
            pred_length=self.prediction_length,
        )
    )

    # Static groups are horizontally stacked; their placeholder is 1-D.
    transforms += stack_or_default(
        FieldName.FEAT_STATIC_CAT,
        self.static_cardinalities,
        [0.0],
        1,
        h_stack=True,
    )
    transforms += stack_or_default(
        FieldName.FEAT_STATIC_REAL,
        self.static_feature_dims,
        [0.0],
        1,
        h_stack=True,
    )

    # The dynamic categorical placeholder is broadcast with an extra
    # ``prediction_length`` steps.
    transforms += stack_or_default(
        FieldName.FEAT_DYNAMIC_CAT,
        self.dynamic_cardinalities,
        [[0.0]],
        2,
        broadcast={"ext_length": self.prediction_length},
    )

    # Time features always lead the dynamic real group, followed by any
    # declared dynamic real fields.
    transforms.append(
        VstackFeatures(
            input_fields=[
                FieldName.FEAT_TIME,
                *self.dynamic_feature_dims.keys(),
            ],
            output_field=FieldName.FEAT_DYNAMIC_REAL,
        )
    )

    # Past-only groups: the placeholder is broadcast with BroadcastTo's
    # default arguments.
    past_cat_field = FieldName.PAST_FEAT_DYNAMIC + "_cat"
    transforms += stack_or_default(
        past_cat_field,
        self.past_dynamic_cardinalities,
        [[0.0]],
        2,
        broadcast={},
    )
    transforms += stack_or_default(
        FieldName.PAST_FEAT_DYNAMIC_REAL,
        self.past_dynamic_feature_dims,
        [[0.0]],
        2,
        broadcast={},
    )

    return Chain(transforms)