in src/hit/api/linearalgebra/linearalgebra.cpp [779:823]
// Shared core of the encrypted matrix product A * B, where A is supplied in
// transposed form (enc_mat_a_trans).
//
// Precondition: enc_mat_b must be one level below enc_mat_a_trans.
//
// Each column of A^T (i.e. each row of A) is multiplied against B on its own,
// and the per-row products are then summed in groups of one encoding-unit
// height to form the ciphertexts of the result matrix.
//
// enc_mat_a_trans  encrypted A^T; its width gives the number of result rows
// enc_mat_b        encrypted B; its width gives the number of result columns
// scalar           forwarded unchanged to matrix_matrix_mul_loop_row_major
// transpose_unit   when true, the result uses the transpose of A^T's encoding
//                  unit; also forwarded to matrix_matrix_mul_loop_row_major
//
// Returns an EncryptedMatrix of dimensions
// enc_mat_a_trans.width() x enc_mat_b.width().
EncryptedMatrix LinearAlgebra::multiply_common(const EncryptedMatrix &enc_mat_a_trans,
                                               const EncryptedMatrix &enc_mat_b, double scalar,
                                               bool transpose_unit) {
    const auto num_rows = enc_mat_a_trans.width();

    // per_row_products[r] holds a *single* row (possibly distributed across
    // several ciphertexts): the r-th row of A times the matrix B.
    vector<EncryptedColVector> per_row_products(num_rows);
    parallel_for(num_rows, [&](int row) {
        per_row_products[row] =
            matrix_matrix_mul_loop_row_major(enc_mat_a_trans, enc_mat_b, scalar, row, transpose_unit);
    });

    // Encoding unit of the result: A^T's unit, transposed on request.
    EncodingUnit result_unit = enc_mat_a_trans.encoding_unit();
    if (transpose_unit) {
        result_unit = result_unit.transpose();
    }
    const auto rows_per_unit = result_unit.encoding_height();

    // Number of encoding units stacked vertically in the result (rounded up so
    // that a partially filled final unit is still counted).
    int vertical_unit_count = ceil(num_rows / static_cast<double>(rows_per_unit));

    vector<vector<CKKSCiphertext>> stacked_cts;
    stacked_cts.reserve(vertical_unit_count);
    for (int unit_idx = 0; unit_idx < vertical_unit_count; unit_idx++) {
        const auto first_row = unit_idx * rows_per_unit;

        // Start from the ColVector holding the first row of this unit, then
        // add the remaining rows that belong to the same unit.
        EncryptedColVector accumulated = per_row_products[first_row];

        // There are exactly num_rows entries in per_row_products, which need
        // not be a multiple of rows_per_unit (trailing rows of the last unit
        // may be 0-padding), so the second condition stops the accumulation
        // once every computed row has been consumed. It can only trigger for
        // the final unit.
        for (int offset = 1; offset < rows_per_unit && first_row + offset < num_rows; offset++) {
            add_inplace(accumulated, per_row_products[first_row + offset]);
        }
        stacked_cts.push_back(accumulated.cts);
    }

    return EncryptedMatrix(num_rows, enc_mat_b.width(), result_unit, stacked_cts);
}