in src/gluonts/nursery/SCott/pts/model/deepvar/deepvar_estimator.py [0:0]
def create_transformation(self) -> Transformation:
    """Build the preprocessing chain applied to dataset entries.

    The transformations, in the order they are applied:

    * remove ``FieldName.FEAT_DYNAMIC_CAT`` always, plus
      ``FieldName.FEAT_DYNAMIC_REAL`` / ``FieldName.FEAT_STATIC_REAL`` when
      ``self.use_feat_dynamic_real`` / ``self.use_feat_static_real`` is off;
    * set placeholder static features (``[0]`` for
      ``FieldName.FEAT_STATIC_CAT``, ``[0.0]`` for
      ``FieldName.FEAT_STATIC_REAL``) when the matching
      ``self.use_feat_static_*`` flag is off;
    * convert the target to a NumPy array, expand its dims along axis 0
      when ``self.distr_output.event_shape[0] == 1``, and add the
      observed-values indicator;
    * add time features and a log-scaled age feature, then stack them
      (together with ``FieldName.FEAT_DYNAMIC_REAL`` when enabled) into
      ``FieldName.FEAT_TIME``;
    * add a ``"target_dimension_indicator"`` field and convert both static
      feature fields to 1-D arrays;
    * split each entry into past/future windows with ``InstanceSplitter``
      (past length ``self.history_length``, future length
      ``self.prediction_length``);
    * finally, either apply ``CDFtoGaussianTransform`` (when
      ``self.use_marginal_transformation`` is true) or rename the
      ``past_``/``future_`` target fields to their ``_cdf``-suffixed names.

    Returns
    -------
    Transformation
        A ``Chain`` of the transformations listed above.
    """
    # Fields that are never (or, per the flags, not) used are dropped first.
    remove_field_names = [FieldName.FEAT_DYNAMIC_CAT]
    if not self.use_feat_dynamic_real:
        remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
    if not self.use_feat_static_real:
        remove_field_names.append(FieldName.FEAT_STATIC_REAL)

    transforms = [RemoveFields(field_names=remove_field_names)]

    # Disabled static features get a constant placeholder so the
    # AsNumpyArray steps further down always have a value to convert.
    if not self.use_feat_static_cat:
        transforms.append(
            SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])
        )
    if not self.use_feat_static_real:
        transforms.append(
            SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])
        )

    time_feature_fields = [FieldName.FEAT_TIME, FieldName.FEAT_AGE]
    if self.use_feat_dynamic_real:
        time_feature_fields.append(FieldName.FEAT_DYNAMIC_REAL)

    transforms.extend(
        [
            AsNumpyArray(
                field=FieldName.TARGET,
                expected_ndim=1 + len(self.distr_output.event_shape),
            ),
            # maps the target to (1, T)
            # if the target data is uni dimensional
            ExpandDimArray(
                field=FieldName.TARGET,
                axis=0 if self.distr_output.event_shape[0] == 1 else None,
            ),
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            ),
            AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=self.time_features,
                pred_length=self.prediction_length,
            ),
            AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=self.prediction_length,
                log_scale=True,
                dtype=self.dtype,
            ),
            VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=time_feature_fields,
            ),
            TargetDimIndicator(
                field_name="target_dimension_indicator",
                target_field=FieldName.TARGET,
            ),
            # np.long was removed in NumPy 1.24 (and redefined as the
            # platform C long, 32-bit on Windows, in NumPy 2.0); int64 is
            # stable across versions and matches torch's long dtype.
            AsNumpyArray(
                field=FieldName.FEAT_STATIC_CAT,
                expected_ndim=1,
                dtype=np.int64,
            ),
            AsNumpyArray(field=FieldName.FEAT_STATIC_REAL, expected_ndim=1),
            InstanceSplitter(
                target_field=FieldName.TARGET,
                is_pad_field=FieldName.IS_PAD,
                start_field=FieldName.START,
                forecast_start_field=FieldName.FORECAST_START,
                train_sampler=ExpectedNumInstanceSampler(num_instances=1),
                past_length=self.history_length,
                future_length=self.prediction_length,
                time_series_fields=[
                    FieldName.FEAT_TIME,
                    FieldName.OBSERVED_VALUES,
                ],
                pick_incomplete=self.pick_incomplete,
            ),
        ]
    )

    # Last step: either the marginal CDF->Gaussian transform, or a plain
    # rename to the "_cdf" field names (presumably so downstream code can
    # read the same field names in both modes -- confirm against the
    # network's input names).
    if self.use_marginal_transformation:
        transforms.append(
            CDFtoGaussianTransform(
                target_field=FieldName.TARGET,
                observed_values_field=FieldName.OBSERVED_VALUES,
                max_context_length=self.conditioning_length,
                target_dim=self.target_dim,
            )
        )
    else:
        transforms.append(
            RenameFields(
                {
                    f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
                    f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
                }
            )
        )

    return Chain(transforms)