in src/nv-wavenet/matrix.cpp [131:159]
// Compares matrices A and B element by element, treating A as the reference.
//
//   name      - label printed in the "Comparing ..." banner.
//   A, B      - matrices to compare; they must have identical dimensions
//               (checked with assert).
//   max_error - tolerance: the maximum allowed relative error |B/A - 1| for
//               ordinary elements, and the upper bound on both values in the
//               ReLU case described below.
//   relu      - when true, an element pair in which either value is <= 0 is
//               accepted as long as both values are below max_error.
//
// Prints the first mismatching element and trips assert(false). Note that when
// NDEBUG is defined the asserts compile away, so mismatches are only printed
// and the final " SUCCESS!" line is still emitted.
void matrix_compare(const char* name, Matrix& A, Matrix& B, float max_error, bool relu) {
    assert(A.rows() == B.rows());
    assert(A.cols() == B.cols());
    printf("Comparing %s\n", name);
    for (int row = 0; row < A.rows(); row++) {
        for (int col = 0; col < A.cols(); col++) {
            float A_data = A.get(row,col);
            float B_data = B.get(row,col);
            bool correct = false;
            if (relu && (A_data <= 0.f || B_data <= 0.f)) {
                // At least one value is non-positive: only require that
                // neither value exceeds the tolerance.
                correct = A_data < max_error && B_data < max_error;
            } else if (A_data == 0) {
                // A relative error is undefined against a zero reference,
                // so demand an exact match.
                correct = (B_data == 0);
            } else {
                // Relative error of B with respect to A. The absolute value
                // must wrap the whole difference (ratio - 1); taking fabs of
                // the ratio alone would accept any B smaller in magnitude
                // than A, as well as sign-flipped values.
                correct = fabs(B_data/A_data - 1) <= max_error;
            }
            if (!correct) {
                printf(" mismatch at %d,%d: %f vs %f\n", row, col, A_data, B_data);
                assert(false);
            }
        }
    }
    printf(" SUCCESS!\n");
}