in src/hit/api/linearalgebra/linearalgebra.cpp [456:497]
// Element-wise (Hadamard) product of an encrypted matrix with an encrypted column vector.
//
// The vector is aligned with the matrix's columns: enc_mat.width() must equal
// enc_vec.height(), and matrix ciphertext unit (unit_row, unit_col) is multiplied by vector
// ciphertext unit unit_col. Presumably, under the EncryptedColVector encoding, this scales
// column j of the matrix by vec[j] -- NOTE(review): confirm against the encoding docs.
//
// Preconditions (each violation throws via TRY_AND_THROW_STREAM / LOG_AND_THROW_STREAM):
//   - both arguments pass validate();
//   - both use the same encoding unit;
//   - enc_mat.width() == enc_vec.height();
//   - both are at the same HE level and have the same scale;
//   - neither argument needs a rescale or a relinearization.
//
// Returns a new EncryptedMatrix with enc_mat's height, width, and encoding unit; the inputs
// are not modified. No relinearize/rescale is performed in this function, so the result
// presumably needs both before it is multiplied again (TODO confirm evaluator semantics).
EncryptedMatrix LinearAlgebra::hadamard_multiply(const EncryptedMatrix &enc_mat,
                                                 const EncryptedColVector &enc_vec) {
    TRY_AND_THROW_STREAM(enc_mat.validate(),
                         "The EncryptedMatrix argument to hadamard_multiply is invalid; has it been initialized?");
    TRY_AND_THROW_STREAM(
        enc_vec.validate(),
        "The EncryptedColVector argument to hadamard_multiply is invalid; has it been initialized?");
    if (enc_mat.encoding_unit() != enc_vec.encoding_unit()) {
        LOG_AND_THROW_STREAM("Inputs to hadamard_multiply must have the same units: "
                             << dim_string(enc_mat.encoding_unit()) << "!=" << dim_string(enc_vec.encoding_unit()));
    }
    // The vector is matched against the matrix's columns, so its height must equal the matrix width.
    if (enc_mat.width() != enc_vec.height()) {
        LOG_AND_THROW_STREAM("Inner dimension mismatch in hadamard_multiply: "
                             << dim_string(enc_mat) << " is not compatible with " << dim_string(enc_vec));
    }
    if (enc_mat.he_level() != enc_vec.he_level()) {
        LOG_AND_THROW_STREAM("Inputs to hadamard_multiply must have the same level: " << enc_mat.he_level() << "!="
                                                                                      << enc_vec.he_level());
    }
    // Exact comparison is intentional: the two operands must carry identical CKKS scales.
    if (enc_mat.scale() != enc_vec.scale()) {
        LOG_AND_THROW_STREAM("Inputs to hadamard_multiply must have the same scale: "
                             << log2(enc_mat.scale()) << " bits != " << log2(enc_vec.scale()) << " bits");
    }
    // Report each flag next to the argument it belongs to (the labels were previously swapped).
    if (enc_mat.needs_rescale() || enc_vec.needs_rescale()) {
        LOG_AND_THROW_STREAM("Inputs to hadamard_multiply must have nominal scale: "
                             << "Matrix: " << enc_mat.needs_rescale() << ", Vector: " << enc_vec.needs_rescale());
    }
    if (enc_mat.needs_relin() || enc_vec.needs_relin()) {
        LOG_AND_THROW_STREAM("Inputs to hadamard_multiply must be linear ciphertexts: "
                             << "Matrix: " << enc_mat.needs_relin() << ", Vector: " << enc_vec.needs_relin());
    }

    // Copy the matrix ciphertexts so each unit can be multiplied in place without touching enc_mat.
    vector<vector<CKKSCiphertext>> cts = enc_mat.cts;
    parallel_for(enc_mat.num_vertical_units() * enc_mat.num_horizontal_units(), [&](int i) {
        // Flatten (unit_row, unit_col) into a single index; each task writes a distinct element of cts.
        const int unit_row = i / enc_mat.num_horizontal_units();
        const int unit_col = i % enc_mat.num_horizontal_units();
        eval.multiply_inplace(cts[unit_row][unit_col], enc_vec.cts[unit_col]);
    });
    return EncryptedMatrix(enc_mat.height(), enc_mat.width(), enc_mat.encoding_unit(), cts);
}