void lstm_objective_b()

in src/cpp/modules/tapenade/lstm/lstm_b.c [303:376]


/*
 * Reverse-mode (adjoint) sweep of the LSTM objective. Generated by Tapenade,
 * with manual repairs marked TFIX.
 *
 * Forward sweep: for each prediction step t = 0, b, ..., (c-2)*b it calls
 * lstm_predict_nodiff on the current input and records on the AD tape what
 * the reverse sweep needs: the previous ypred, the state, and the
 * ygold/input pointers. The target of each step (ygold) is the next b-block
 * of `sequence`, which then becomes the next input.
 *
 * Reverse sweep: seeds the adjoint from *lossb and walks the recorded steps
 * last-to-first. As the adjoint statements below reflect, the primal
 * objective is
 *     loss = -total / count,  total += sum_i ygold[i] * (ypred[i] - lse),
 * where lse = logsumexp(ypred, b) and count = b * (number of steps).
 * Adjoints are propagated through logsumexp_b and lstm_predict_b.
 *
 * Parameters (sizes taken from the loops/allocations below; the meaning of
 * l and b is inferred from names, NOTE(review): confirm against the primal):
 *   l, c, b       - l presumably layer count; c number of b-wide blocks in
 *                   `sequence` (c-1 prediction steps); b block width.
 *   main_params   - read-only primal parameters.
 *   main_paramsb  - out: 8*l*b adjoint entries, zeroed here and presumably
 *                   accumulated by lstm_predict_b.
 *   extra_params  - read-only primal parameters.
 *   extra_paramsb - out: 3*b adjoint entries, zeroed here and presumably
 *                   accumulated by lstm_predict_b.
 *   state         - 2*l*b entries, updated by the forward sweep and restored
 *                   from the tape during the reverse sweep.
 *   sequence      - c*b entries read (blocks 0..c-1).
 *   loss          - not read or written by this routine.
 *   lossb         - in: adjoint seed of the loss; reset to 0.0.
 */
void lstm_objective_b(int l, int c, int b, const double *main_params, double *
        main_paramsb, const double *extra_params, double *extra_paramsb, 
        double *state, const double *sequence, double *loss, double *lossb) {
    int i, t, ii1;
    int branch;
    int count = 0;   /* number of predicted values (b per step) */
    int nsteps = 0;  /* number of forward steps recorded on the tape */
    double totalb;
    double lseb;
    const double *input = &(sequence[0]);
    const double *ygold = NULL; /* TFIX */
    /* Adjoint buffers must start at zero. ypred is pushed on the tape before
       its first write, so it is zero-filled too to keep that snapshot defined. */
    double *stateb = (double *)calloc((size_t)2 * l * b, sizeof(double)); /* TFIX */
    double *ypred = (double *)calloc((size_t)b, sizeof(double));
    double *ypredb = (double *)calloc((size_t)b, sizeof(double));
    double *ynormb = (double *)calloc((size_t)b, sizeof(double));

    (void)loss; /* the primal loss is not recomputed by this adjoint routine */

    /* ---- Forward sweep: record per-step values on the tape. ---- */
    for (t = 0; t <= (c-1)*b-1; t += b) {
        if (ypred) {
            pushReal8Array(ypred, b); /* TFIX */
            pushControl1b(1);
        } else
            pushControl1b(0);
        pushReal8Array(state, 2 * b * l); /* TFIX */
        lstm_predict_nodiff(l, b, main_params, extra_params, state, input, 
                            ypred);
        pushPointer8(ygold);
        ygold = &(sequence[t + b]);
        count = count + b;
        pushPointer8(input);
        input = ygold;
        ++nsteps;
    }

    /* loss = -total / count. With no prediction steps, count is 0 and nothing
       below consumes totalb, so avoid forming inf/NaN. */
    totalb = (count > 0) ? -(*lossb/count) : 0.0;
    *lossb = 0.0;
    for (ii1 = 0; ii1 < 8 * l * b; ii1++) /* TFIX */
        main_paramsb[ii1] = 0.0;
    for (ii1 = 0; ii1 < 3 * b; ii1++) /* TFIX */
        extra_paramsb[ii1] = 0.0;

    /* ---- Reverse sweep: exactly one iteration per recorded forward step.
       Counting steps (rather than recomputing the last t from c and b) keeps
       pops matched to pushes, including when the forward sweep was empty. ---- */
    for (; nsteps > 0; --nsteps) {
        popPointer8((void **)&input);
        /* total += sum_i ygold[i] * ynorm[i] */
        for (i = b-1; i > -1; --i)
            ynormb[i] = ynormb[i] + ygold[i]*totalb;
        popPointer8((void **)&ygold);
        /* ynorm[i] = ypred[i] - lse */
        lseb = 0.0;
        for (i = b-1; i > -1; --i) {
            ypredb[i] = ypredb[i] + ynormb[i];
            lseb = lseb - ynormb[i];
            ynormb[i] = 0.0;
        }
        /* lse = logsumexp(ypred, b); ypred still holds this step's prediction. */
        logsumexp_b(ypred, ypredb, b, lseb);
        /* Restore the values that were live before this step's prediction. */
        popReal8Array(state, 2 * b * l); /* TFIX */
        popControl1b(&branch);
        if (branch == 1)
            popReal8Array(ypred, b); /* TFIX */
        lstm_predict_b(l, b, main_params, main_paramsb, extra_params, 
                       extra_paramsb, state, stateb, input, ypred, ypredb);
    }

    free(ynormb);
    free(ypred);
    free(ypredb);
    free(stateb); /* TFIX */
}