EncryptedColVector LinearAlgebra::matrix_matrix_mul_loop_row_major()

in src/hit/api/linearalgebra/linearalgebra.cpp [654:699]


    /**
     * One iteration of the row-major matrix-matrix product loop.
     *
     * The function runs in three steps:
     *  1. Extract row `k` of `enc_mat_a_trans`.
     *  2. Multiply that row by `enc_mat_b` and rescale the result.
     *  3. Multiply the result by a plaintext mask.
     *
     * The mask is zero everywhere except the selected row of the encoding unit,
     * whose slots hold `scalar`. As a result, masking out the row and applying the
     * constant factor cost a single plaintext multiplication. The masked row is not
     * replicated; the caller adds the per-k results together to build the product.
     *
     * @param enc_mat_a_trans  Left operand, supplied in transposed form.
     * @param enc_mat_b        Right operand; its encoding unit (shared by both
     *                         inputs) defines the mask layout.
     * @param scalar           Value placed in the selected mask slots.
     * @param k                Index of the row being processed.
     * @param transpose_unit   When true, the mask is laid out relative to the
     *                         transpose of enc_mat_b's encoding unit.
     * @return The rescaled product with every slot outside the selected row zeroed.
     */
    EncryptedColVector LinearAlgebra::matrix_matrix_mul_loop_row_major(const EncryptedMatrix &enc_mat_a_trans,
                                                                       const EncryptedMatrix &enc_mat_b, double scalar,
                                                                       int k, bool transpose_unit) {
        // Row k of A^T times B: a column vector stored in row layout, in which
        // every row currently holds identical values.
        EncryptedRowVector row_k = extract_row(enc_mat_a_trans, k);
        EncryptedColVector product = multiply(row_k, enc_mat_b);
        rescale_to_next_inplace(product);

        const int slot_count = enc_mat_a_trans.num_slots();

        // Mask geometry: B's encoding unit, or its transpose when the inputs use an
        // n-by-m unit and the mask must be built for the m-by-n one.
        EncodingUnit mask_unit = enc_mat_b.encoding_unit();
        if (transpose_unit) {
            mask_unit = mask_unit.transpose();
        }
        const int unit_height = mask_unit.encoding_height();
        const int unit_width = mask_unit.encoding_width();

        // Row of the unit that keeps its values. The plain layout wraps k into the unit.
        // The transposed layout uses k directly.
        // NOTE(review): for a transposed unit nothing is selected when k >= unit_height;
        // callers presumably guarantee k < unit_height in that case -- confirm.
        const int target_row = transpose_unit ? k : k % unit_height;

        // Columns [0, col_limit) of that row are kept. A transposed unit keeps only
        // its first unit_height columns.
        int col_limit = unit_width;
        if (transpose_unit && unit_height < unit_width) {
            col_limit = unit_height;
        }

        // Zero everywhere except the selected row, which carries `scalar`.
        vector<double> row_mask(slot_count, 0.0);
        if (target_row >= 0 && target_row < unit_height) {
            const int row_start = target_row * unit_width;
            for (int col = 0; col < col_limit; col++) {
                row_mask[row_start + col] = scalar;
            }
        }

        // Apply the same mask to every horizontally adjacent unit of the vector.
        for (auto &unit_ct : product.cts) {
            eval.multiply_plain_inplace(unit_ct, row_mask);
        }

        return product;
    }