void nvWavenetLayer()

in src/nv-wavenet/nv_wavenet_reference.cpp [58:94]


void nvWavenetLayer(int r, int batch_size, Matrix& Wprev, Matrix& Wcur, Matrix& Bh, Matrix& Lh, Matrix& Wres, Matrix& Bres, Matrix& Wskip, Matrix& Bskip, Matrix& Xtmd, Matrix& Xin, Matrix& Xout, Matrix& skipIn, Matrix& skipOut, bool lastLayer, int sample) {

    Matrix a_prev(2*r, batch_size, false);
    matrix_multiply(a_prev, Wprev, Xtmd);

    Matrix a_cur(2*r, batch_size, false);
    matrix_multiply(a_cur, Wcur, Xin);

    Matrix h_prime(2*r, batch_size, false); 

    matrix_add(h_prime, a_prev, a_cur);
    matrix_bias(h_prime, h_prime, Bh);

    matrix_add(h_prime, h_prime, Lh);

    Matrix h(r, batch_size, false);

    for (int batch_idx=0; batch_idx<batch_size; batch_idx++) {
        for (int row = 0; row < r; row++) {
            h.set(row, batch_idx, tanh_proxy(h_prime.get(row+r, batch_idx)) * sigmoid_proxy(h_prime.get(row, batch_idx)));
        }
    }

    matrix_multiply(Xout, Wres, h);
    matrix_bias(Xout, Xout, Bres);
    matrix_add(Xout, Xout, Xin);

    matrix_multiply(skipOut, Wskip, h);
    matrix_add(skipOut, skipOut, skipIn);
    matrix_bias(skipOut,skipOut,Bskip);

    /*if (sample >= 0 && sample <= 2) {
        printf("Ref test: %f\n", h.get(0,0));
    }*/

//    if (lastLayer) matrix_relu(skipOut, skipOut);
}