in src/nv-wavenet/nv_wavenet_reference.cpp [58:94]
void nvWavenetLayer(int r, int batch_size, Matrix& Wprev, Matrix& Wcur, Matrix& Bh, Matrix& Lh, Matrix& Wres, Matrix& Bres, Matrix& Wskip, Matrix& Bskip, Matrix& Xtmd, Matrix& Xin, Matrix& Xout, Matrix& skipIn, Matrix& skipOut, bool lastLayer, int sample) {
Matrix a_prev(2*r, batch_size, false);
matrix_multiply(a_prev, Wprev, Xtmd);
Matrix a_cur(2*r, batch_size, false);
matrix_multiply(a_cur, Wcur, Xin);
Matrix h_prime(2*r, batch_size, false);
matrix_add(h_prime, a_prev, a_cur);
matrix_bias(h_prime, h_prime, Bh);
matrix_add(h_prime, h_prime, Lh);
Matrix h(r, batch_size, false);
for (int batch_idx=0; batch_idx<batch_size; batch_idx++) {
for (int row = 0; row < r; row++) {
h.set(row, batch_idx, tanh_proxy(h_prime.get(row+r, batch_idx)) * sigmoid_proxy(h_prime.get(row, batch_idx)));
}
}
matrix_multiply(Xout, Wres, h);
matrix_bias(Xout, Xout, Bres);
matrix_add(Xout, Xout, Xin);
matrix_multiply(skipOut, Wskip, h);
matrix_add(skipOut, skipOut, skipIn);
matrix_bias(skipOut,skipOut,Bskip);
/*if (sample >= 0 && sample <= 2) {
printf("Ref test: %f\n", h.get(0,0));
}*/
// if (lastLayer) matrix_relu(skipOut, skipOut);
}