in datawig/imputer.py [0:0]
def __init__(self,
             data_encoders: List[ColumnEncoder],
             data_featurizers: List[Featurizer],
             label_encoders: List[ColumnEncoder],
             output_path="") -> None:
    """Validate the encoder/featurizer configuration and prepare the output directory.

    Besides storing the arguments, this initialises the bookkeeping attributes
    (losses, training time, calibration info, ...), creates the output directory
    if needed, attaches a log file handler writing to ``imputer.log`` inside it,
    and derives the model and metrics paths from it.

    :param data_encoders: encoders for the input columns; every element must be
        a ``ColumnEncoder`` instance.
    :param data_featurizers: featurizers for the encoded inputs; there must be
        exactly as many as ``data_encoders``, every element must be a
        ``Featurizer`` instance, and its ``field_name`` must equal the
        ``output_column`` of one of the data encoders.
    :param label_encoders: encoders for the label columns; the type of every
        element must be exactly ``CategoricalEncoder`` or ``NumericalEncoder``.
    :param output_path: directory for the log file, model and metrics. If it is
        empty (or ``"."``), a directory named after the label encoders' output
        columns is used instead.
    :raises ValueError: if the number of featurizers and encoders differ, if an
        element has an unsupported type, if a featurizer reads a field that no
        data encoder produces, or if a label column is also used as an input.
    """
    self.ctx = None
    self.module = None
    self.data_encoders = data_encoders
    self.batch_size = 16
    self.data_featurizers = data_featurizers
    self.label_encoders = label_encoders
    self.final_fc_hidden_units = []
    self.train_losses = None
    self.test_losses = None
    self.training_time = 0.
    self.calibration_temperature = None
    self.precision_recall_curves = {}
    self.calibration_info = {}
    self.__class_patterns = None

    # explainability only works for Categorical and Tfidf inputs with a single categorical output column
    # (builtin any() keeps the flag a plain bool in every branch)
    has_explainable_input = any(
        isinstance(encoder, (CategoricalEncoder, TfIdfEncoder))
        for encoder in self.data_encoders)
    self.is_explainable = (has_explainable_input
                           and len(self.label_encoders) == 1
                           and isinstance(self.label_encoders[0], CategoricalEncoder))

    if len(self.data_featurizers) != len(self.data_encoders):
        # Arguments are passed in the order the placeholders name them:
        # featurizer count first, encoder count second.
        raise ValueError(
            "Argument Number of data_featurizers ({}) "
            "must match number of data_encoders ({})".format(
                len(self.data_featurizers), len(self.data_encoders)))

    for encoder in self.data_encoders:
        if not isinstance(encoder, ColumnEncoder):
            raise ValueError(
                "Arguments passed as data_encoder must be valid "
                "datawig.column_encoders.ColumnEncoder, was {}".format(type(encoder)))

    for encoder in self.label_encoders:
        # Exact type match on purpose: subclasses of the two encoders are
        # rejected, exactly as before.
        if type(encoder) not in (CategoricalEncoder, NumericalEncoder):
            raise ValueError(
                "Arguments passed as label_columns must be "
                "datawig.column_encoders.CategoricalEncoder or NumericalEncoder, "
                "was {}".format(type(encoder)))

    encoder_outputs = [encoder.output_column for encoder in self.data_encoders]

    for featurizer in self.data_featurizers:
        if not isinstance(featurizer, Featurizer):
            raise ValueError(
                "Arguments passed as data_featurizers must be valid "
                "datawig.mxnet_input_symbols.Featurizer type, "
                "was {}".format(type(featurizer)))

        if featurizer.field_name not in encoder_outputs:
            raise ValueError(
                "List of encoder outputs [{}] does not contain featurizer input for {}".format(
                    ", ".join(encoder_outputs), type(featurizer)))
        # TODO: check whether encoder type matches requirements of featurizer

    # collect names of data and label columns
    input_col_names = {featurizer.field_name for featurizer in self.data_featurizers}
    label_col_names = set(itertools.chain.from_iterable(
        encoder.input_columns for encoder in self.label_encoders))

    if not input_col_names.isdisjoint(label_col_names):
        raise ValueError("cannot train with label columns that are in the input")

    # if there is no output directory provided, try to write to current dir
    if not output_path:
        output_path = '.'

    self.output_path = output_path

    # if there was no output dir provided, name it to the label (-list) fitted
    if self.output_path == ".":
        label_names = [encoder.output_column.lower().replace(" ", "_")
                       for encoder in self.label_encoders]
        self.output_path = "-".join(label_names)

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(self.output_path, exist_ok=True)

    self.__attach_log_filehandler(filename=os.path.join(self.output_path, 'imputer.log'))

    self.module_path = os.path.join(self.output_path, "model")
    self.metrics_path = os.path.join(self.output_path, "fit-test-metrics.json")