in src/sagemaker_sklearn_extension/contrib/taei/latent_space_oversampler.py [0:0]
def resample(self, X, y, verbose=False):
    """
    Use the model and the base oversampler to generate synthetic minority samples.

    ``X`` is encoded with ``self.model.encode``. ``self.base_oversampler`` is
    applied to the latent codes, and only the newly generated latent points are
    decoded back with ``self.model.decode_sample``. The decoded rows are
    appended after the original rows.

    Parameters
    ----------
    X : array-like
        Feature matrix; validated together with ``y`` by ``check_X_y``.
    y : array-like
        Target labels.
    verbose : bool, default False
        If True, print the latent-space shapes before and after oversampling.

    Returns
    -------
    X : numpy.ndarray
        The original rows (after conversion through ``torch.Tensor``) followed
        by the decoded synthetic rows.
    y : numpy.ndarray
        The original labels followed by the labels of the synthetic rows.
    """
    X, y = check_X_y(X, y)
    self.model.eval()
    X = torch.Tensor(X).to(self.device)
    with torch.no_grad():
        z = self.model.encode(X)
    z = z.cpu().numpy()
    if verbose:
        print(f"LatentSpaceOversampler: Shape before oversampling z:{z.shape}, y:{y.shape}")
    z_samples, y_samples = self.base_oversampler(z, y)
    if verbose:
        print(f"LatentSpaceOversampler: Shape after oversampling z:{z_samples.shape}, y:{y_samples.shape}")

    # Assumes the base oversampler returns the original points first, followed
    # by the synthetic ones (the same assumption the tail-slicing relies on).
    n_new = len(z_samples) - len(X)
    if n_new <= 0:
        # Nothing was generated. The previous `arr[-n_new:]` slicing turned into
        # `arr[-0:]` here, which selects *every* row and duplicated the dataset.
        return X.cpu().numpy(), y

    z_samples = z_samples[len(X):]
    y_samples = y_samples[len(y):].reshape(-1)

    z_samples = torch.Tensor(z_samples).to(self.device)
    with torch.no_grad():
        x_samples = self.model.decode_sample(z_samples)
    X = torch.cat([X, x_samples], dim=0).cpu().numpy()
    y = np.concatenate((y, y_samples), axis=0)
    return X, y