in src/peft/utils/incremental_pca.py [0:0]
def partial_fit(self, X, check_input=True):
    """
    Update the fitted decomposition with one additional batch `X`.

    The first call is detected by the absence of a `components_` attribute; on that call the running
    statistics (`mean_`, `var_`, `n_samples_seen_`), `n_features_` and, when `n_components` is falsy,
    `n_components` (set to `min(n_samples, n_features)` of this batch) are initialised. On later calls the
    previous decomposition (`singular_values_` scaled `components_`), the centered batch and one
    mean-correction row are stacked and decomposed again.

    Note:
        The batch is centered in place (`-=`). With `check_input=False` the tensor passed by the caller is
        therefore modified. With `check_input=True` the tensor returned by `_validate_data` is the one
        modified; whether that is a copy is not visible from this method.

    Args:
        X (torch.Tensor): The batch input data tensor with shape (n_samples, n_features).
        check_input (bool, optional): If True, `X` is first passed through `self._validate_data`.
            Defaults to True.

    Returns:
        IncrementalPCA: `self`, with `components_`, `singular_values_`, `mean_`, `var_`,
        `n_samples_seen_`, `explained_variance_`, `explained_variance_ratio_` and `noise_variance_`
        refreshed.

    Raises:
        ValueError: If the batch has a different number of features than the first batch.
    """
    # A model that has never seen data has no `components_` attribute yet.
    is_first_batch = not hasattr(self, "components_")

    batch = self._validate_data(X) if check_input else X
    batch_rows, feature_count = batch.shape

    if is_first_batch:
        # None placeholders are handed to _incremental_mean_and_var on this first call.
        self.mean_ = None
        self.var_ = None
        self.n_samples_seen_ = torch.tensor([0], device=batch.device)
        self.n_features_ = feature_count
        if not self.n_components:
            self.n_components = min(batch_rows, feature_count)

    if feature_count != self.n_features_:
        raise ValueError(
            "Number of features of the new batch does not match the number of features of the first batch."
        )

    # Running statistics that include the current batch.
    running_mean, running_var, total_seen = self._incremental_mean_and_var(
        batch, self.mean_, self.var_, self.n_samples_seen_
    )

    if is_first_batch:
        batch -= running_mean
        stacked = batch
    else:
        # Center on the batch's own mean, then add one extra row that accounts for the shift between
        # the previously stored mean and this batch's mean.
        batch_mean = torch.mean(batch, dim=0)
        batch -= batch_mean
        seen_fraction = self.n_samples_seen_.double() / total_seen
        correction_scale = torch.sqrt(seen_fraction * batch_rows)
        correction_row = correction_scale * (self.mean_ - batch_mean)
        previous_basis = self.singular_values_.view((-1, 1)) * self.components_
        stacked = torch.vstack((previous_basis, batch, correction_row))

    # `lowrank` selects which of the two SVD helpers is used.
    decompose = self._svd_fn_lowrank if self.lowrank else self._svd_fn_full
    U, S, Vt = decompose(stacked)
    # NOTE(review): presumably normalises the sign of the singular vectors - confirm in _svd_flip.
    U, Vt = self._svd_flip(U, Vt, u_based_decision=False)

    squared_singular = S**2
    explained_variance = squared_singular / (total_seen - 1)
    explained_variance_ratio = squared_singular / torch.sum(running_var * total_seen)

    keep = self.n_components
    self.n_samples_seen_ = total_seen
    self.components_ = Vt[:keep]
    self.singular_values_ = S[:keep]
    self.mean_ = running_mean
    self.var_ = running_var
    self.explained_variance_ = explained_variance[:keep]
    self.explained_variance_ratio_ = explained_variance_ratio[:keep]

    if keep in (batch_rows, feature_count):
        # All components are kept relative to this batch's dimensions: report zero noise variance.
        self.noise_variance_ = torch.tensor(0.0, device=stacked.device)
    else:
        # Mean explained variance of the discarded components.
        self.noise_variance_ = explained_variance[keep:].mean()
    return self