def iterations_mom_sketch_proj()

in ridge_sketch.py (relies on the module-level imports numpy as np and sklearn.utils.extmath.safe_sparse_dot)


    def iterations_mom_sketch_proj(self, b, w, sketch_method, m):
        # Initialize the residual to r_0 = A @ w_0 - b (= -b, since w_0 = 0)
        w_previous = w.copy()
        r = -b.copy()
        r_previous = r.copy()
        r_norm_initial = np.linalg.norm(r)
        r_norm = r_norm_initial
        self.residual_norms.append(1.0)

        if self.verbose:
            print()  # blank line before the progress log

        mom_eta_cst = self.mom_eta
        mom_eta_k = mom_eta_cst
        mom_lmbda_k = 0.0
        switch_flag = False  # True once eta has switched from its constant value to 1
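
        # Momentum schedule notation (both branches below share the same
        # lmbda recursion; the heuristic branch forces a unit step size):
        #   eta_k:   relaxation sequence (a constant value, later switched to 1)
        #   lmbda_k: lmbda_0 = 0 and
        #            lmbda_{k+1} = eta_k * (1 + lmbda_k - eta_k) / eta_{k+1}
        #   step size gamma_k = eta_k / (lmbda_{k+1} + 1)   (theory branch)
        #   momentum beta_k   = lmbda_k / (lmbda_{k+1} + 1)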
        for i in range(self.max_iter):
            if self.verbose and (i % 1000 == 0):
                print(f"iter:{i:^8d} |  rel res norm: {(r_norm / r_norm_initial):.2e}")

            if self.mom_eta is not None:
                # If self.mom_eta was set by the user, update the momentum
                # parameters (step size and beta) from the eta/lmbda recursions

                if self.use_heuristic:
                    # HEURISTIC (disabled): unit step size, beta capped at 1/2
                    # self.step_size = 1.
                    # self.mom_beta = min(1 - ((2 - self.mom_eta) / mom_zeta), .5)

                    # HEURISTIC: sequence of etas with unit step size
                    if self.step_size is not None and self.mom_beta is not None:
                        if self.mom_beta >= 0.5 and not switch_flag:
                            switch_flag = True
                            mom_eta_cst = 1.0  # switch eta to 1 once beta reaches 1/2

                    # Update eta to its current constant value
                    mom_eta_k_plus_1 = mom_eta_cst

                    # Update lmbda
                    mom_lmbda_k_plus_1 = (
                        mom_eta_k * (1.0 + mom_lmbda_k - mom_eta_k) / mom_eta_k_plus_1
                    )

                    # Update step size and beta
                    self.step_size = 1.0  # enforce a unit step size
                    self.mom_beta = mom_lmbda_k / (mom_lmbda_k_plus_1 + 1.0)

                    mom_eta_k = mom_eta_k_plus_1
                    mom_lmbda_k = mom_lmbda_k_plus_1
                else:
                    # THEORY (disabled): constant eta
                    # mom_zeta_plus_1 = (i + 1) * (1 - self.mom_eta) + 1
                    # self.step_size = self.mom_eta / mom_zeta_plus_1
                    # self.mom_beta = 1 - ((2 - self.mom_eta) / mom_zeta_plus_1)

                    # THEORY: sequence of etas (constant value < 1, then 1)
                    if self.step_size is not None and self.mom_beta is not None:
                        if self.step_size <= 0.5 and not switch_flag:
                            switch_flag = True
                            mom_eta_cst = 1.0  # switch eta to 1 once the step size decays to 1/2

                    # Update eta to its current constant value
                    mom_eta_k_plus_1 = mom_eta_cst

                    # Update lmbda
                    mom_lmbda_k_plus_1 = (
                        mom_eta_k * (1.0 + mom_lmbda_k - mom_eta_k) / mom_eta_k_plus_1
                    )

                    # Update step size and beta
                    self.step_size = mom_eta_k / (mom_lmbda_k_plus_1 + 1.0)
                    self.mom_beta = mom_lmbda_k / (mom_lmbda_k_plus_1 + 1.0)

                    mom_eta_k = mom_eta_k_plus_1
                    mom_lmbda_k = mom_lmbda_k_plus_1
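
                    # Note: before the switch, with a constant eta < 1 and
                    # lmbda_0 = 0, this recursion has the closed form
                    # lmbda_k = k * (1 - eta); writing
                    # zeta_{k+1} = (k + 1) * (1 - eta) + 1 recovers
                    #   step_size = eta / zeta_{k+1}
                    #   beta      = 1 - (2 - eta) / zeta_{k+1}
                    # i.e. exactly the disabled "THEORY: constant eta" lines.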

            # Sketch the current system: SA = S^T A, SAS = S^T A S, rs = S^T r
            SA, SAS, rs = sketch_method.sketch(r)
            lmbda = self.solve_system(SAS, rs)
            # Update the iterates
            diff_iterates = w - w_previous  # w^k - w^{k-1}
            w_previous = w.copy()  # store w^k as the new previous iterate
            sketch_method.update_iterate(w, lmbda, step_size=self.step_size)
            w += self.mom_beta * diff_iterates  # momentum term: beta * (w^k - w^{k-1})
            # Update the residuals
            r_tmp = r.copy()
            r *= 1.0 + self.mom_beta
            r -= self.mom_beta * r_previous
            r -= self.step_size * safe_sparse_dot(SA.T, lmbda)
            r_previous = r_tmp  # r_tmp is already a copy of the old r
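            # Residual recursion used above: from
            #   w^{k+1} = w^k - step_size * S @ lmbda + beta * (w^k - w^{k-1})
            # and r = A @ w - b, it follows that
            #   r^{k+1} = (1 + beta) * r^k - beta * r^{k-1} - step_size * A @ S @ lmbda,
            # where A @ S @ lmbda == SA.T @ lmbda, assuming (as in ridge
            # regression) that the system matrix A is symmetric.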
            r_norm = np.linalg.norm(r)
            err = r_norm / r_norm_initial  # relative residual norm
            self.residual_norms.append(err)
            if err < self.tol:
                break
        self.iterations = i  # index of the last iteration performed
        return w
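
For orientation, the snippet below is a minimal, self-contained sketch of the same momentum sketch-and-project recursion, using the constant-eta theory schedule (no switch) and a plain row-subsampling sketch on a symmetric positive definite system. All names in it (momentum_sketch_project, the test problem) are illustrative assumptions, not the ridge_sketch.py API:

    import numpy as np

    def momentum_sketch_project(A, b, eta=0.99, sketch_size=10,
                                max_iter=2000, tol=1e-6, seed=0):
        """Momentum sketch-and-project for a symmetric positive definite
        system A @ w = b, with a constant-eta momentum schedule."""
        rng = np.random.default_rng(seed)
        n = A.shape[0]
        w = np.zeros(n)
        w_previous = w.copy()
        r = -b.copy()  # r_0 = A @ w_0 - b with w_0 = 0
        r_previous = r.copy()
        r_norm_initial = np.linalg.norm(r)

        lmbda_k = 0.0
        for _ in range(max_iter):
            # Schedule: lmbda_{k+1} = eta * (1 + lmbda_k - eta) / eta
            lmbda_next = 1.0 + lmbda_k - eta
            step_size = eta / (lmbda_next + 1.0)
            beta = lmbda_k / (lmbda_next + 1.0)
            lmbda_k = lmbda_next

            # Row-subsampling sketch: S selects sketch_size coordinates
            idx = rng.choice(n, size=sketch_size, replace=False)
            SA = A[idx]                 # S^T A
            SAS = SA[:, idx]            # S^T A S
            lmbda = np.linalg.solve(SAS, r[idx])  # (S^T A S) lmbda = S^T r

            # Iterate update with heavy-ball momentum
            diff_iterates = w - w_previous
            w_previous = w.copy()
            w[idx] -= step_size * lmbda  # w <- w - step_size * S @ lmbda
            w += beta * diff_iterates

            # Residual recursion, as in the method above
            r_tmp = r.copy()
            r = (1.0 + beta) * r - beta * r_previous - step_size * (SA.T @ lmbda)
            r_previous = r_tmp
            if np.linalg.norm(r) / r_norm_initial < tol:
                break
        return w

    # Example: a small ridge system (X^T X + reg * I) @ w = X^T y
    rng = np.random.default_rng(0)
    X, y = rng.standard_normal((200, 50)), rng.standard_normal(200)
    w = momentum_sketch_project(X.T @ X + 1.0 * np.eye(50), X.T @ y)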