EncryptedRowVector LinearAlgebra::matrix_matrix_mul_loop_col_major()

in src/hit/api/linearalgebra/linearalgebra.cpp [605:647]


    /* One iteration of the column-by-column matrix-matrix product.
     *
     * Multiplies `enc_mat_a` by the k-th column obtained from `enc_mat_b_trans`
     * via `extract_col(enc_mat_b_trans, k)`, and scales the result by `scalar`.
     * The product is NOT replicated across the columns of each encoding unit.
     * Each returned ciphertext holds its values only in unit column
     * `k % encoding_width`, and every other slot is zero. A caller can
     * therefore add the outputs for several values of `k` into one encoding.
     *
     * Consumes one level through `rescale_to_next_inplace` on the hadamard
     * product. The final plaintext multiply is not rescaled here.
     * NOTE(review): presumably the caller rescales it; confirm.
     *
     * @param enc_mat_a        left-hand encrypted matrix
     * @param enc_mat_b_trans  encrypted transpose of the right-hand matrix
     * @param scalar           constant folded into the column mask
     * @param k                index of the column being computed
     * @return row vector of length `enc_mat_a.height()` whose values sit in
     *         unit column `k % encoding_width`
     */
    EncryptedRowVector LinearAlgebra::matrix_matrix_mul_loop_col_major(const EncryptedMatrix &enc_mat_a,
                                                                       const EncryptedMatrix &enc_mat_b_trans,
                                                                       double scalar, int k) {
        EncryptedColVector b_col_k = extract_col(enc_mat_b_trans, k);
        EncodingUnit unit = enc_mat_a.encoding_unit();

        // Calling `multiply` directly would also work. It would run
        // hadamard_multiply and then `sum_cols`, which replicates the result
        // into every column. Doing the column sum by hand lets us fold the
        // scaling, masking and placement into the same pass. We want exactly
        // one copy of the output column, positioned for later accumulation.
        EncryptedMatrix products = hadamard_multiply(enc_mat_a, b_col_k);
        relinearize_inplace(products);
        rescale_to_next_inplace(products);

        const auto width = unit.encoding_width();
        const int num_slots = enc_mat_b_trans.num_slots();

        // Plaintext mask: `scalar` in the first column of every unit row, zero
        // elsewhere. A single multiply_plain therefore both scales and masks.
        vector<double> first_col_mask(num_slots, 0);
        for (int slot = 0; slot < num_slots; slot++) {
            first_col_mask[slot] = (slot % width == 0) ? scalar : 0;
        }

        const auto target_col = k % width;
        const auto num_rows = enc_mat_a.num_vertical_units();
        vector<CKKSCiphertext> row_cts(num_rows);
        parallel_for(num_rows, [&](int row) {
            // Add up all horizontal units belonging to this vertical position.
            CKKSCiphertext acc = eval.add_many(products.cts[row]);
            // Collapse the unit's columns so the row sums land in column 0.
            rot(acc, width, 1, true);
            // Keep only column 0 (scaled), then move it to column k % width.
            row_cts[row] = eval.multiply_plain(acc, first_col_mask);
            eval.rotate_right_inplace(row_cts[row], target_col);
        });

        return EncryptedRowVector(enc_mat_a.height(), unit, row_cts);
    }