EncryptedMatrix LinearAlgebra::multiply_common()

in src/hit/api/linearalgebra/linearalgebra.cpp [779:823]


    /* Shared core of the encrypted matrix product: computes the product of A and B, given the
     * encrypted transpose of A (`enc_mat_a_trans`) and the encrypted B (`enc_mat_b`).
     *
     * Each row k of the result (column k of A^T) is computed independently and in parallel by
     * `matrix_matrix_mul_loop_row_major`. Groups of `unit.encoding_height()` consecutive rows are
     * then summed into a single horizontal row of encoding units.
     *
     * Parameters:
     *   enc_mat_a_trans - encrypted A^T; its width is the number of rows of the result.
     *   enc_mat_b       - encrypted B; its width is the number of columns of the result.
     *                     Requires b to be at one level below enc_mat_a_trans.
     *   scalar          - forwarded unchanged to matrix_matrix_mul_loop_row_major
     *                     (presumably a factor applied to the product -- confirm there).
     *   transpose_unit  - forwarded to matrix_matrix_mul_loop_row_major. When true, the result is
     *                     built with the transpose of enc_mat_a_trans's encoding unit.
     *
     * Returns: an EncryptedMatrix of size enc_mat_a_trans.width() x enc_mat_b.width().
     */
    EncryptedMatrix LinearAlgebra::multiply_common(const EncryptedMatrix &enc_mat_a_trans,
                                                   const EncryptedMatrix &enc_mat_b, double scalar,
                                                   bool transpose_unit) {
        // This function requires b to be at one level below enc_mat_a_trans.

        // One result row per column of A^T (i.e. per row of A).
        const int num_result_rows = enc_mat_a_trans.width();

        // Compute the k^th row of A times B for every k. Each task writes only its own slot,
        // so the parallel writes do not conflict.
        vector<EncryptedColVector> row_results(num_result_rows);
        parallel_for(num_result_rows, [&](int k) {
            row_results[k] = matrix_matrix_mul_loop_row_major(enc_mat_a_trans, enc_mat_b, scalar, k, transpose_unit);
        });

        // row_results[i] contains a *single* row (possibly distributed across several cts)
        // containing the i^th row of A times the matrix B.
        // The next step is to add unit.encoding_height() of these together to make a single unit row.
        EncodingUnit unit = enc_mat_a_trans.encoding_unit();
        if (transpose_unit) {
            unit = unit.transpose();
        }
        const int unit_height = unit.encoding_height();

        // Integer ceiling division: the number of vertical encoding units needed to hold all rows.
        const int result_vertical_units = (num_result_rows + unit_height - 1) / unit_height;
        vector<vector<CKKSCiphertext>> matrix_cts(result_vertical_units);

        for (int i = 0; i < result_vertical_units; i++) {
            const int first_row = i * unit_height;
            // The last unit may be only partially filled (the remainder is 0-padding), so clamp
            // the group to the rows that actually exist in row_results.
            int end_row = first_row + unit_height;
            if (end_row > num_result_rows) {
                end_row = num_result_rows;
            }

            // row_results[first_row] is never read again, so move it instead of deep-copying
            // its ciphertexts. It seeds the accumulator for this unit row.
            EncryptedColVector unit_row_cts = std::move(row_results[first_row]);
            for (int r = first_row + 1; r < end_row; r++) {
                add_inplace(unit_row_cts, row_results[r]);
            }
            matrix_cts[i] = std::move(unit_row_cts.cts);
        }

        return EncryptedMatrix(enc_mat_a_trans.width(), enc_mat_b.width(), unit, matrix_cts);
    }