in src/hit/api/linearalgebra/linearalgebra.cpp [724:776]
// Compute the encrypted matrix product A*B column by column. The second argument holds B^T, so each
// row of enc_mat_b_trans is one column of B.
//
// Preconditions (violations are reported through LOG_AND_THROW_STREAM):
//  - the checks performed by matrix_multiply_validation;
//  - enc_mat_a must be exactly one HE level below enc_mat_b_trans (level(A) + 1 == level(B^T));
//  - enc_mat_a.width() must equal enc_mat_b_trans.width().
//
// `scalar` is forwarded unchanged to matrix_matrix_mul_loop_col_major for every column
// (NOTE(review): presumably a constant factor folded into the product -- confirm against that helper).
//
// Returns an EncryptedMatrix whose logical size is enc_mat_a.height() x enc_mat_b_trans.height(),
// encoded with enc_mat_a's encoding unit.
EncryptedMatrix LinearAlgebra::multiply_col_major(const EncryptedMatrix &enc_mat_a,
                                                  const EncryptedMatrix &enc_mat_b_trans, double scalar) {
    matrix_multiply_validation(enc_mat_a, enc_mat_b_trans, "multiply_col_major");
    if (enc_mat_a.he_level() + 1 != enc_mat_b_trans.he_level()) {
        // Print the relation that was actually tested: level(A) + 1 must equal level(B^T).
        LOG_AND_THROW_STREAM("First argument to multiply_col_major must be one level below second argument: "
                             << enc_mat_a.he_level() << "+1!=" << enc_mat_b_trans.he_level());
    }
    if (enc_mat_a.width() != enc_mat_b_trans.width()) {
        LOG_AND_THROW_STREAM("Inputs to multiply_col_major do not have compatible dimensions: "
                             << dim_string(enc_mat_a) << " vs " << dim_string(enc_mat_b_trans));
    }

    // Number of columns in the product A*B (= number of rows of B^T).
    int num_result_cols = enc_mat_b_trans.height();

    // Step 1: for each row k of B^T (i.e. column k of B), compute the k-th column of A*B.
    // Each result is an EncryptedRowVector with a single non-zero column, possibly distributed
    // across several vertically stacked ciphertexts. The columns are independent of each other,
    // so they are computed with parallel_for, each task writing only its own slot.
    vector<EncryptedRowVector> col_results(num_result_cols);
    parallel_for(num_result_cols, [&](int k) {
        col_results[k] = matrix_matrix_mul_loop_col_major(enc_mat_a, enc_mat_b_trans, scalar, k);
    });

    // Step 2: pack the single-column results into encoding units. Each horizontal unit of the
    // product is the sum of up to encoding_width() consecutive single-column results.
    EncodingUnit unit = enc_mat_a.encoding_unit();
    int unit_width = unit.encoding_width();
    // Integer ceiling division: the last horizontal unit may be only partially filled, because the
    // number of real columns need not be a multiple of the unit width (the rest is zero padding).
    int result_horizontal_units = (num_result_cols + unit_width - 1) / unit_width;

    vector<vector<CKKSCiphertext>> matrix_cts(enc_mat_a.num_vertical_units());
    for (int i = 0; i < result_horizontal_units; i++) {
        int first_col = i * unit_width;
        // Start from the first column of this horizontal unit and add the remaining columns that
        // actually exist; only the final unit can be cut short by the column-count bound.
        EncryptedRowVector unit_cols = col_results[first_col];
        for (int col = first_col + 1; col < first_col + unit_width && col < num_result_cols; col++) {
            add_inplace(unit_cols, col_results[col]);
        }
        // Distribute the unit's ciphertexts: the j-th vertical ciphertext goes into row j of units.
        for (int j = 0; j < enc_mat_a.num_vertical_units(); j++) {
            matrix_cts[j].push_back(unit_cols.cts[j]);
        }
    }
    return EncryptedMatrix(enc_mat_a.height(), num_result_cols, unit, matrix_cts);
}