in orbit/forecaster/forecaster.py [0:0]
def set_training_data_input(self):
    """Collect data attributes into a dict for the sampling/optimization API.

    The resulting dict is stored on ``self._training_data_input``. It is built
    in three steps:

    1. The standard training inputs (response, response standard deviation,
       response mean and number of observations) are read from
       ``self.get_training_meta()``. They are keyed by the upper-cased
       ``TrainingMetaKeys`` values.
    2. The dict is passed through ``self.set_forecaster_training_meta``.
    3. One entry is added per key declared by the model's
       ``get_data_input_mapper()``.

    The mapper may be either:

    * a list of key names. Each key is used verbatim as the output dict key,
      and its lower-cased form names the model attribute to read.
    * an ``Enum`` subclass. For each member, the lower-cased member name names
      the model attribute, and the member value is used as the output dict key.

    Boolean attribute values are converted to ``int`` because Stan accepts
    bool inputs only as integers.

    Raises:
        ForecasterException: if the mapper is empty/falsy, if it is neither a
            list nor an ``Enum`` subclass, or if a mapped attribute is missing
            from (or ``None`` on) the model.
    """
    # refresh a clean dict
    data_input_mapper = self._model.get_data_input_mapper()
    if not data_input_mapper:
        raise ForecasterException("Empty or invalid data_input_mapper")

    # always get standard input from training
    training_meta = self.get_training_meta()
    standard_keys = (
        TrainingMetaKeys.RESPONSE,
        TrainingMetaKeys.RESPONSE_SD,
        TrainingMetaKeys.RESPONSE_MEAN,
        TrainingMetaKeys.NUM_OF_OBS,
    )
    training_data_input = {
        meta_key.value.upper(): training_meta[meta_key.value]
        for meta_key in standard_keys
    }
    training_data_input = self.set_forecaster_training_meta(
        data_input=training_data_input
    )

    # Resolve (model attribute name, output dict key) pairs so both mapper
    # styles share one extraction loop. Generators keep evaluation lazy, so
    # keys are processed in the same order as before.
    if isinstance(data_input_mapper, list):
        # list entries are reused verbatim as the output key; the model
        # attribute is the lower-cased key
        key_pairs = ((key.lower(), key) for key in data_input_mapper)
    elif isinstance(data_input_mapper, type) and issubclass(data_input_mapper, Enum):
        # member names are upper case while model attributes are lower case,
        # so a case conversion is needed; the member value is the output key
        key_pairs = (
            (member.name.lower(), member.value) for member in data_input_mapper
        )
    else:
        # the isinstance(..., type) guard above ensures non-class mappers
        # (tuples, dicts, ...) land here instead of crashing in issubclass()
        raise ForecasterException(
            "Invalid type: data_input_mapper needs to be either an Enum or list."
        )

    for attr_name, input_key in key_pairs:
        input_value = getattr(self._model, attr_name, None)
        if input_value is None:
            raise ForecasterException(
                "{} is missing from data input".format(attr_name)
            )
        # stan accepts bool as int only
        if isinstance(input_value, bool):
            input_value = int(input_value)
        training_data_input[input_key] = input_value

    self._training_data_input = training_data_input