EncryptedMatrix LinearAlgebra::multiply_row_major_mixed_unit()

in src/hit/api/linearalgebra/linearalgebra.cpp [842:872]


    // Computes the product A*B where the left operand A is supplied in transposed form (enc_mat_a_trans = A^T).
    //
    // Shapes, as enforced by the checks below: enc_mat_a_trans is g-by-f, enc_mat_b is g-by-h, and the
    // f-by-h result must fit inside a single encoding unit in either orientation (f, h <= unit width m,
    // with m <= unit height n).
    //
    // @param enc_mat_a_trans  A^T (g-by-f); must be exactly one HE level above enc_mat_b.
    // @param enc_mat_b        B (g-by-h); must have the same height g as enc_mat_a_trans.
    // @param scalar           Forwarded unchanged to multiply_common (presumably a factor applied to the
    //                         product -- TODO confirm against multiply_common).
    // @return                 The matrix produced by multiply_common(enc_mat_a_trans, enc_mat_b, scalar, true).
    // @throws                 Via LOG_AND_THROW_STREAM when the level, height, encoding-unit shape, or width
    //                         requirements are violated, plus anything matrix_multiply_validation raises.
    EncryptedMatrix LinearAlgebra::multiply_row_major_mixed_unit(const EncryptedMatrix &enc_mat_a_trans,
                                                                 const EncryptedMatrix &enc_mat_b, double scalar) {
        matrix_multiply_validation(enc_mat_a_trans, enc_mat_b, "multiply_row_major_mixed_unit");

        // Level precondition: B must sit exactly one HE level below A^T.
        if (enc_mat_a_trans.he_level() != enc_mat_b.he_level() + 1) {
            LOG_AND_THROW_STREAM(
                "Second argument to multiply_row_major_mixed_unit must be one level below first argument: "
                << enc_mat_a_trans.he_level() << "!=" << enc_mat_b.he_level() << "+1");
        }

        // A^T (g-by-f) and B (g-by-h) must share the contracted dimension g.
        if (enc_mat_a_trans.height() != enc_mat_b.height()) {
            LOG_AND_THROW_STREAM("Inputs to multiply_row_major_mixed_unit do not have compatible dimensions: "
                                 << dim_string(enc_mat_a_trans) << " vs " << dim_string(enc_mat_b));
        }

        // Inputs are encoded with an n-by-m unit where we require m <= n.
        // Bind by const reference: avoids copying the unit, and is safe whether encoding_unit() returns a
        // reference or a value (a temporary bound to a const reference lives until the end of this scope).
        // NOTE(review): only enc_mat_a_trans's unit is inspected; enc_mat_b is assumed to share it (possibly
        // enforced by matrix_multiply_validation -- verify).
        const EncodingUnit &unit = enc_mat_a_trans.encoding_unit();
        if (unit.encoding_width() > unit.encoding_height()) {
            // Stream the operands like every other error in this function instead of building a temporary
            // std::string with operator+; the emitted message is unchanged.
            LOG_AND_THROW_STREAM("Inputs to multiply_row_major_mixed_unit are encoded with an invalid "
                                 << dim_string(unit));
        }

        // A^T is g-by-f and B is g-by-h; we require f, h <= m so the f-by-h output fits in one unit of either
        // orientation (m-by-n or n-by-m).
        if (enc_mat_a_trans.width() > unit.encoding_width() || enc_mat_b.width() > unit.encoding_width()) {
            LOG_AND_THROW_STREAM("Inputs to multiply_row_major_mixed_unit do not have valid dimensions: The "
                                 << enc_mat_a_trans.width() << "-by-" << enc_mat_b.width()
                                 << " output must fit into a single " << unit.encoding_width() << "-by-"
                                 << unit.encoding_height() << " unit and a single " << unit.encoding_height() << "-by-"
                                 << unit.encoding_width() << " unit");
        }

        // Multiply each row of A by the matrix B. The result is a list of EncryptedColVectors, each with a single
        // non-zero row, then sum the results.
        // The trailing `true` selects multiply_common's behaviour for this mixed-unit entry point (exact meaning
        // is defined in multiply_common -- TODO confirm).
        return multiply_common(enc_mat_a_trans, enc_mat_b, scalar, true);
    }