EncryptedMatrix LinearAlgebra::multiply_col_major()

in src/hit/api/linearalgebra/linearalgebra.cpp [724:776]


    // Homomorphic matrix product taking A and the *transpose* of B as inputs.
    //
    // Each row k of enc_mat_b_trans (i.e. column k of B) is handed to
    // matrix_matrix_mul_loop_col_major together with A and `scalar`; the helper returns an
    // EncryptedRowVector which, per the contract relied on below, holds a single non-zero
    // column: column k of the product. Those per-column results are then summed in groups of
    // unit.encoding_width() so that each group fills one horizontal encoding unit of the result.
    //
    // Parameters:
    //   enc_mat_a        left operand A; must be exactly one HE level below enc_mat_b_trans.
    //   enc_mat_b_trans  B^T; must have the same width as A.
    //   scalar           forwarded unchanged to the per-column helper
    //                    (presumably a constant factor folded into the product - confirm in the helper).
    //
    // Returns an EncryptedMatrix of dimension enc_mat_a.height() x enc_mat_b_trans.height()
    // using A's encoding unit.
    //
    // Reports an error through LOG_AND_THROW_STREAM when the level or dimension precondition
    // is violated (matrix_multiply_validation may reject the inputs as well).
    EncryptedMatrix LinearAlgebra::multiply_col_major(const EncryptedMatrix &enc_mat_a,
                                                      const EncryptedMatrix &enc_mat_b_trans, double scalar) {
        matrix_multiply_validation(enc_mat_a, enc_mat_b_trans, "multiply_col_major");
        // Required relation: level(A) == level(B^T) - 1. The diagnostic prints that same relation.
        if (enc_mat_a.he_level() + 1 != enc_mat_b_trans.he_level()) {
            LOG_AND_THROW_STREAM("First argument to multiply_col_major must be one level below second argument: "
                                 << enc_mat_a.he_level() << "!=" << enc_mat_b_trans.he_level() << "-1");
        }
        if (enc_mat_a.width() != enc_mat_b_trans.width()) {
            LOG_AND_THROW_STREAM("Inputs to multiply_col_major do not have compatible dimensions: "
                                 << dim_string(enc_mat_a) + " vs " + dim_string(enc_mat_b_trans));
        }

        // One result column per row of B^T (= per column of B).
        const int num_result_cols = enc_mat_b_trans.height();

        // col_results[k] is column k of the product (possibly distributed across several vertical
        // ciphertexts), with every other column zero. The columns are independent of one another,
        // so they are computed in parallel.
        vector<EncryptedRowVector> col_results(num_result_cols);
        parallel_for(num_result_cols, [&](int k) {
            col_results[k] = matrix_matrix_mul_loop_col_major(enc_mat_a, enc_mat_b_trans, scalar, k);
        });

        // Because each column result is non-zero in a different column, adding encoding_width
        // consecutive results packs them into a single encoding unit.
        EncodingUnit unit = enc_mat_a.encoding_unit();
        const int unit_width = unit.encoding_width();
        const int num_vertical_units = enc_mat_a.num_vertical_units();
        const int result_horizontal_units = ceil(num_result_cols / static_cast<double>(unit_width));

        // matrix_cts[v] is the v-th row of encoding units in the result; it receives exactly one
        // ciphertext per horizontal unit.
        vector<vector<CKKSCiphertext>> matrix_cts(num_vertical_units);
        for (int v = 0; v < num_vertical_units; v++) {
            matrix_cts[v].reserve(result_horizontal_units);
        }

        // Assemble the result one column of encoding units at a time.
        for (int i = 0; i < result_horizontal_units; i++) {
            const int first_col = i * unit_width;
            // The last unit may be only partially filled: there are exactly num_result_cols
            // column results, which need not be a multiple of the unit width (the remaining
            // columns of that unit are zero padding). Clamp the range accordingly.
            const int end_col = (first_col + unit_width < num_result_cols) ? first_col + unit_width : num_result_cols;

            // Accumulate directly into the first column result of this unit. col_results is local
            // and each entry is consumed exactly once, so no copy of the row vector is needed.
            EncryptedRowVector &unit_sum = col_results[first_col];
            for (int col = first_col + 1; col < end_col; col++) {
                add_inplace(unit_sum, col_results[col]);
            }

            // Distribute the summed column of units across the rows of the result.
            for (int v = 0; v < num_vertical_units; v++) {
                matrix_cts[v].push_back(unit_sum.cts[v]);
            }
        }

        return EncryptedMatrix(enc_mat_a.height(), num_result_cols, unit, matrix_cts);
    }