def resample()

in src/sagemaker_sklearn_extension/contrib/taei/latent_space_oversampler.py [0:0]


    def resample(self, X, y, verbose=False):
        """
        Use the model and the base oversampler to generate synthetic minority samples.

        The inputs are encoded into the model's latent space with ``self.model.encode``.
        ``self.base_oversampler`` is applied to the latent codes. Only the newly generated
        latent samples are decoded back with ``self.model.decode_sample`` and appended to
        the original data.

        Parameters
        ----------
        X : array-like
            Input samples; validated with ``check_X_y``.
        y : array-like
            Target labels; validated with ``check_X_y``.
        verbose : bool, default=False
            If True, print the latent-space shapes before and after oversampling.

        Returns
        -------
        X : numpy.ndarray
            The original samples followed by the decoded synthetic samples. They pass
            through ``torch.Tensor``, so they come back in torch's default float dtype.
        y : numpy.ndarray
            The original labels followed by the labels of the synthetic samples.

        Raises
        ------
        ValueError
            If the base oversampler returns fewer samples than it was given.
        """
        X, y = check_X_y(X, y)
        self.model.eval()
        X = torch.Tensor(X).to(self.device)
        with torch.no_grad():
            z = self.model.encode(X)
        z = z.cpu().numpy()
        if verbose:
            print(f"LatentSpaceOversampler: Shape before oversampling z:{z.shape}, y:{y.shape}")
        z_samples, y_samples = self.base_oversampler(z, y)
        if verbose:
            print(f"LatentSpaceOversampler: Shape after oversampling z:{z_samples.shape}, y:{y_samples.shape}")

        n_original = len(X)
        if len(z_samples) < n_original or len(y_samples) < n_original:
            raise ValueError(
                f"LatentSpaceOversampler: base_oversampler returned {len(z_samples)} latent samples and "
                f"{len(y_samples)} labels, fewer than the {n_original} it was given"
            )
        # NOTE(review): assumes the base oversampler returns the original samples first,
        # with the synthetic ones appended after them (imbalanced-learn's SMOTE behaves
        # this way). Slicing from n_original yields an empty selection when nothing was
        # generated. The former `[-(n_total - n_original):]` form degenerated to `[0:]`
        # in that case and re-appended every original sample.
        z_new = z_samples[n_original:]
        y_new = y_samples[n_original:].reshape(-1)
        if len(z_new) == 0:
            # No synthetic samples were produced, so there is nothing to decode.
            return X.cpu().numpy(), y

        z_new = torch.Tensor(z_new).to(self.device)
        with torch.no_grad():
            x_new = self.model.decode_sample(z_new)
        X = torch.cat([X, x_new], dim=0).cpu().numpy()
        y = np.concatenate((y, y_new), axis=0)
        return X, y