void matrix_compare()

in src/nv-wavenet/matrix.cpp [131:159]


/**
 * Element-wise comparison of two equally-shaped matrices: B is checked
 * against A, which is used as the reference value for each element.
 *
 * @param name      Label printed before the comparison starts.
 * @param A         Reference matrix; must have the same rows/cols as B.
 * @param B         Matrix being checked against A.
 * @param max_error Tolerance. For ordinary elements it bounds the relative
 *                  error |B/A - 1|. For elements handled by the relu rule it
 *                  is an upper bound that both values must stay below.
 * @param relu      If true, any element where either value is <= 0 is accepted
 *                  when both values are < max_error (presumably because such
 *                  values are clamped by a ReLU — confirm against callers).
 *
 * Each mismatch is printed. With assertions enabled, the first mismatch aborts
 * via assert(false). Under NDEBUG all mismatches are reported, and "SUCCESS!"
 * is printed only if none were found.
 */
void matrix_compare(const char* name, Matrix& A, Matrix& B, float max_error, bool relu) {
    assert(A.rows() == B.rows());
    assert(A.cols() == B.cols());

    printf("Comparing %s\n", name);
    int mismatches = 0;
    for (int row = 0; row < A.rows(); row++) {
        for (int col = 0; col < A.cols(); col++) {
            const float A_data = A.get(row, col);
            const float B_data = B.get(row, col);

            bool correct = false;
            if (relu && (A_data <= 0.f || B_data <= 0.f)) {
                correct = A_data < max_error && B_data < max_error;
            } else if (A_data == 0.f) {
                // Relative error is undefined for a zero reference; require an exact zero.
                correct = (B_data == 0.f);
            } else {
                // Relative error |B/A - 1|. The absolute value must cover the whole
                // difference; fabs(B/A) - 1 would accept any ratio whose magnitude is
                // <= 1 + max_error, including sign flips (B == -A) and B == 0.
                correct = fabs(B_data / A_data - 1.f) <= max_error;
            }
            if (!correct) {
                printf("  mismatch at %d,%d: %f vs %f\n", row, col, A_data, B_data);
                mismatches++;
                // Aborts in debug builds; compiled out under NDEBUG, where the
                // mismatch count below keeps the final report honest.
                assert(false);
            }
        }
    }
    if (mismatches == 0) {
        printf("  SUCCESS!\n");
    } else {
        printf("  FAILED: %d mismatches\n", mismatches);
    }
}