def __init__()

in bayesmark/sklearn_funcs.py [0:0]


    def __init__(self, model, dataset, metric, shuffle_seed=0, data_root=None):
        """Build class that wraps sklearn classifier/regressor CV score for use as an objective function.

        Parameters
        ----------
        model : str
            Which model to use, must be key in `MODELS_CLF` or `MODELS_REG` dict depending on if dataset is
            classification or regression.
        dataset : str
            Which data set to use, must be key in `DATA_LOADERS` dict, or name of custom csv file.
        metric : str
            Which sklearn scoring metric to use. Must be in `METRICS` and compatible with the dataset's problem
            type, i.e., contained in ``METRICS_LOOKUP[problem_type]``.
        shuffle_seed : int
            Random seed to use when splitting the data into train and validation in the cross-validation splits. This
            is needed in order to keep the split constant across calls. Otherwise there would be extra noise in the
            objective function for varying splits.
        data_root : str or None
            Root directory to look for all custom csv files.

        Raises
        ------
        AssertionError
            If the loaded problem type is neither classification nor regression, if the loaded data fails validation
            (not 2D/1D arrays, mismatched lengths, empty, wrong dtype, or non-finite values), or if `metric` is
            unknown or incompatible with the problem type.
        KeyError
            If `model` is not a key of the model lookup for the dataset's problem type.
        """
        TestFunction.__init__(self)
        data, target, problem_type = load_data(dataset, data_root=data_root)
        assert problem_type in (ProblemType.clf, ProblemType.reg)
        self.is_classifier = problem_type == ProblemType.clf

        # Do some validation on loaded data. These are asserts (project convention), so note they are skipped when
        # Python runs with -O.
        # `np.float64` is used instead of the `np.float_` alias, which was removed in NumPy 2.0 (it raised
        # AttributeError there). On older NumPy `np.float_ is np.float64`, so the check is unchanged.
        assert isinstance(data, np.ndarray)
        assert isinstance(target, np.ndarray)
        assert data.ndim == 2 and target.ndim == 1
        assert data.shape[0] == target.shape[0]
        assert data.size > 0
        assert data.dtype == np.float64
        assert np.all(np.isfinite(data))  # also catch nan
        assert target.dtype == (np.int_ if self.is_classifier else np.float64)
        assert np.all(np.isfinite(target))  # also catch nan

        # Pick the model table matching the problem type; unknown model names raise KeyError here.
        model_lookup = MODELS_CLF if self.is_classifier else MODELS_REG
        base_model, fixed_params, api_config = model_lookup[model]

        # New members for model
        self.base_model = base_model
        self.fixed_params = fixed_params
        self.api_config = api_config

        # Always shuffle your data to be safe. Use fixed seed for reprod.
        # data_X/data_y hold the 80% training portion; data_Xt/data_yt hold the 20% held-out portion.
        self.data_X, self.data_Xt, self.data_y, self.data_yt = train_test_split(
            data, target, test_size=0.2, random_state=shuffle_seed, shuffle=True
        )

        # The metric must be known at all, and also valid for this dataset's problem type.
        assert metric in METRICS, "Unknown metric %s" % metric
        assert metric in METRICS_LOOKUP[problem_type], "Incompatible metric %s with problem type %s" % (
            metric,
            problem_type,
        )
        self.scorer = get_scorer(SklearnModel._METRIC_MAP[metric])