def fit()

in src/sagemaker_sklearn_extension/contrib/taei/latent_space_oversampler.py [0:0]


    def fit(self, X, y, validation_ratio=0.2, **kwargs):
        """
        Train the model using gradient descent back propagation

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Features matrix used to train the model. Input is validated by ``check_X_y`` with its default
            settings, so sparse matrices are rejected and dense data is required
        y : array-like of shape (n_samples,)
            The target vector used to train the model. A column vector of shape (n_samples, 1) is also accepted
            and flattened to 1-D by ``check_X_y``
        validation_ratio : float or None (default = 0.2)
            Ratio of samples to be used as validation set for early stopping in model training. The split is
            stratified on ``y`` and seeded with ``self.random_state``. If None (or any falsy value such as 0),
            no validation set is held out and ``None`` is passed for the validation data, so early stopping is
            not applied
        **kwargs:
            Additional arguments passed to the model internal fit function

        Returns
        -------
        self : object
            This instance, returned to allow method chaining
        """
        X, y = check_X_y(X, y)
        # A falsy ratio (None or 0) disables the hold-out split entirely.
        if validation_ratio:
            # Stratify on y so the validation set keeps the label proportions of the full data;
            # self.random_state seeds the shuffle.
            X_train, X_validation, y_train, y_validation = train_test_split(
                X, y, test_size=validation_ratio, stratify=y, random_state=self.random_state
            )
        else:
            X_train, y_train = X, y
            X_validation = y_validation = None
        self.model.fit(
            X_train=X_train,
            y_train=y_train,
            X_validation=X_validation,
            y_validation=y_validation,
            device=self.device,
            **kwargs,
        )
        return self