in src/hit/api/linearalgebra/linearalgebra.cpp [842:872]
/**
 * Multiplies two encrypted matrices where the first operand is supplied in transposed form.
 *
 * Given an encryption of A^T (g-by-f) and of B (g-by-h), this validates that the operands are
 * compatible with the mixed-unit algorithm and then delegates the homomorphic computation to
 * multiply_common. Presumably the result is an encryption of scalar * A * B (f-by-h) -- TODO
 * confirm against multiply_common.
 *
 * @param enc_mat_a_trans Encryption of A^T. Must be exactly one HE level above enc_mat_b.
 *                        Its encoding unit (n-by-m) is the one checked below.
 * @param enc_mat_b       Encryption of B. Must have the same height as enc_mat_a_trans.
 * @param scalar          Forwarded unchanged to multiply_common.
 * @return The EncryptedMatrix produced by multiply_common.
 *
 * Throws (via LOG_AND_THROW_STREAM) when any of the following holds:
 *   - matrix_multiply_validation rejects the inputs;
 *   - enc_mat_a_trans.he_level() != enc_mat_b.he_level() + 1;
 *   - the two inputs have different heights;
 *   - the encoding unit is wider than it is tall (m > n);
 *   - either input is wider than the encoding-unit width m.
 */
EncryptedMatrix LinearAlgebra::multiply_row_major_mixed_unit(const EncryptedMatrix &enc_mat_a_trans,
                                                             const EncryptedMatrix &enc_mat_b, double scalar) {
    // Shared checks for matrix multiplication. NOTE(review): presumably this also ensures both
    // inputs use the same encoding unit, since only A^T's unit is inspected below -- confirm.
    matrix_multiply_validation(enc_mat_a_trans, enc_mat_b, "multiply_row_major_mixed_unit");

    // This variant requires A^T to sit exactly one level above B.
    if (enc_mat_a_trans.he_level() != enc_mat_b.he_level() + 1) {
        LOG_AND_THROW_STREAM(
            "Second argument to multiply_row_major_mixed_unit must be one level below first argument: "
            << enc_mat_a_trans.he_level() << "!=" << enc_mat_b.he_level() << "+1");
    }

    // A^T is g-by-f and B is g-by-h, so both operands must have the same number of rows (g).
    if (enc_mat_a_trans.height() != enc_mat_b.height()) {
        LOG_AND_THROW_STREAM("Inputs to multiply_row_major_mixed_unit do not have compatible dimensions: "
                             << dim_string(enc_mat_a_trans) << " vs " << dim_string(enc_mat_b));
    }

    // Inputs are encoded with an n-by-m unit (n = encoding height, m = encoding width); we require m <= n.
    // Bound by const reference: safe whether encoding_unit() returns by value or by reference.
    const EncodingUnit &unit = enc_mat_a_trans.encoding_unit();
    if (unit.encoding_width() > unit.encoding_height()) {
        // Stream the pieces (instead of concatenating with operator+) like every other
        // LOG_AND_THROW_STREAM call in this function; the resulting message is unchanged.
        LOG_AND_THROW_STREAM("Inputs to multiply_row_major_mixed_unit are encoded with an invalid "
                             << dim_string(unit));
    }

    // A^T is g-by-f and B is g-by-h; we require f, h <= m so the f-by-h output fits in one unit.
    if (enc_mat_a_trans.width() > unit.encoding_width() || enc_mat_b.width() > unit.encoding_width()) {
        LOG_AND_THROW_STREAM("Inputs to multiply_row_major_mixed_unit do not have valid dimensions: The "
                             << enc_mat_a_trans.width() << "-by-" << enc_mat_b.width()
                             << " output must fit into a single " << unit.encoding_width() << "-by-"
                             << unit.encoding_height() << " unit and a single " << unit.encoding_height()
                             << "-by-" << unit.encoding_width() << " unit");
    }

    // Multiply each row of A by the matrix B. The result is a list of EncryptedColVectors, each with a single
    // non-zero row, then sum the results.
    // NOTE(review): the trailing `true` presumably selects the mixed-unit code path inside
    // multiply_common -- confirm against its declaration.
    return multiply_common(enc_mat_a_trans, enc_mat_b, scalar, true);
}