in src/gluonts/model/seq2seq/_mq_dnn_estimator.py [0:0]
def from_inputs(cls, train_iter, **params):
    """
    Build an estimator from training data plus user-supplied hyperparameters.

    Feature flags are derived from ``train_iter`` via
    ``cls.derive_auto_fields`` and merged with ``params``; user values take
    precedence. A user flag can only switch a feature *off*. Requesting a
    feature that the derived fields report as absent logs a warning and
    forces the flag to ``False``.

    Parameters
    ----------
    cls
        Estimator class; must provide ``derive_auto_fields`` and
        ``from_hyperparameters``.
    train_iter
        Training data, passed unchanged to ``cls.derive_auto_fields``.
    **params
        User hyperparameters. ``use_feat_dynamic_real``,
        ``use_past_feat_dynamic_real`` and ``use_feat_static_cat`` may be
        given as ``bool`` or as truth strings (``"true"``, ``"no"``, ``"1"``,
        ...). They are normalized to ``bool`` before being forwarded.

    Returns
    -------
    Whatever ``cls.from_hyperparameters`` returns for the merged parameters.

    Raises
    ------
    ValueError
        If one of the feature flags is not a recognised truth value.
    """
    logger = logging.getLogger(__name__)
    logger.info(
        f"gluonts[from_inputs]: User supplied params set to {params}"
    )

    def _as_bool(value):
        # Same truth table as distutils.util.strtobool (distutils was removed
        # in Python 3.12), but returns a real bool instead of 1/0 and also
        # accepts non-str values such as numpy bools or the ints 0/1.
        if isinstance(value, bool):
            return value
        text = str(value).lower()
        if text in ("y", "yes", "t", "true", "on", "1"):
            return True
        if text in ("n", "no", "f", "false", "off", "0"):
            return False
        raise ValueError(f"invalid truth value {value!r}")

    # auto_params usually include `use_feat_dynamic_real`, `use_past_feat_dynamic_real`,
    # `use_feat_static_cat` and `cardinality`
    auto_params = cls.derive_auto_fields(train_iter)
    fields = [
        "use_feat_dynamic_real",
        "use_past_feat_dynamic_real",
        "use_feat_static_cat",
    ]
    # User defined arguments become implications: a user flag may disable a
    # feature, but cannot enable one the data does not contain.
    for field in fields:
        if field not in params:
            continue
        is_params_field = _as_bool(params[field])
        # Forward a real bool instead of the raw (possibly string) value.
        params[field] = is_params_field
        if is_params_field and not auto_params[field]:
            logger.warning(
                f"gluonts[from_inputs]: {field} set to False since it is not present in the data."
            )
            params[field] = False
            if field == "use_feat_static_cat":
                params["cardinality"] = None
        elif (
            field == "use_feat_static_cat"
            and not is_params_field
            and auto_params[field]
        ):
            # Static categorical features exist, but the user disabled them,
            # so the data-derived cardinality must not be forwarded.
            params["cardinality"] = None
    # user specified 'params' will take precedence:
    params = {**auto_params, **params}
    # `.get` so this informational log never fails on a key that neither
    # derive_auto_fields nor the user provided.
    logger.info(
        f"gluonts[from_inputs]: use_past_feat_dynamic_real set to "
        f"'{params.get('use_past_feat_dynamic_real')}', use_feat_dynamic_real set to "
        f"'{params.get('use_feat_dynamic_real')}', and use_feat_static_cat set to "
        f"'{params.get('use_feat_static_cat')}' with cardinality of '{params.get('cardinality')}'"
    )
    return cls.from_hyperparameters(**params)