in datawig/imputer.py [0:0]
def fit(self,
        train_df: pd.DataFrame,
        test_df: pd.DataFrame = None,
        ctx: mx.context = None,
        learning_rate: float = 1e-3,
        num_epochs: int = 100,
        patience: int = 3,
        test_split: float = .1,
        weight_decay: float = 0.,
        batch_size: int = 16,
        final_fc_hidden_units: List[int] = None,
        calibrate: bool = True):
    """
    Trains and stores imputer model

    :param train_df: training data as dataframe
    :param test_df: test data as dataframe; if not provided, [test_split] % of the training
                    data are used as test data
    :param ctx: List of mxnet contexts; defaults to get_context() (which falls back to
                [mx.cpu()] if no gpus are available). Users can also pass a list of gpus,
                e.g. [mx.gpu(0), mx.gpu(2), mx.gpu(4)]
    :param learning_rate: learning rate for stochastic gradient descent (default 1e-3)
    :param num_epochs: maximal number of training epochs (default 100)
    :param patience: used for early stopping; after [patience] epochs with no improvement,
                     training is stopped. (default 3)
    :param test_split: if no test_df is provided this is the ratio of test data to be held
                       separate for determining model convergence
    :param weight_decay: regularizer (default 0)
    :param batch_size: default 16
    :param final_fc_hidden_units: list of dimensions for the final fully connected layer.
    :param calibrate: whether to calibrate predictions
    :return: trained imputer model
    """
    # Resolve call-time defaults. A mutable list default would be shared across calls,
    # and get_context() as a signature default would be evaluated only once at import
    # time, freezing device selection — so both use the None-sentinel pattern.
    if final_fc_hidden_units is None:
        final_fc_hidden_units = []
    if ctx is None:
        ctx = get_context()

    # make sure the output directory is writable
    assert os.access(self.output_path, os.W_OK), "Cannot write to directory {}".format(
        self.output_path)

    self.batch_size = batch_size
    self.final_fc_hidden_units = final_fc_hidden_units
    self.ctx = ctx
    logger.debug('Using [{}] as the context for training'.format(ctx))

    if (train_df is None) or (not isinstance(train_df, pd.core.frame.DataFrame)):
        raise ValueError("Need a non-empty DataFrame for fitting Imputer model")

    if test_df is None:
        # hold out [test_split] of the training data for convergence checks
        train_df, test_df = random_split(train_df, [1.0 - test_split, test_split])

    iter_train, iter_test = self.__build_iterators(train_df, test_df, test_split)
    self.__check_data(test_df)

    # to make consecutive calls to .fit() continue where the previous call finished
    if self.module is None:
        self.module = self.__build_module(iter_train)

    self.__fit_module(iter_train, iter_test, learning_rate, num_epochs, patience, weight_decay)

    # Check whether calibration is needed, if so compute and set internal parameter
    # for temperature scaling that is supplied to self.__predict_mxnet_iter()
    if calibrate:
        self.calibrate(iter_test)

    _, metrics = self.__transform_and_compute_metrics_mxnet_iter(iter_test,
                                                                 metrics_path=self.metrics_path)

    for att, att_metric in metrics.items():
        if isinstance(att_metric, dict) and ('precision_recall_curves' in att_metric):
            self.precision_recall_curves[att] = att_metric['precision_recall_curves']

    self.__prune_models()
    self.save()

    if self.is_explainable:
        self.__persist_class_prototypes(iter_train, train_df)

    self.__close_filehandlers()

    return self