void gmm_objective_d()

in src/cpp/modules/manual/gmm_d.cpp [88:170]


/// Computes the GMM objective together with its hand-derived gradient.
///
/// @param d       Dimensionality of each data point.
/// @param k       Number of mixture components.
/// @param n       Number of data points.
/// @param alphas  k per-component weight parameters (combined via log-sum-exp).
/// @param means   k*d component means; component ik starts at means[ik*d].
/// @param icf     k*icf_sz inverse-covariance-factor parameters, where
///                icf_sz = d*(d+1)/2. Within each component block the first d
///                entries are the diagonal parameters and the remaining
///                icf_sz-d entries are the strictly lower-triangular part that
///                is handed to Qtimesx / Qtransposetimesx.
/// @param x       n*d data points; point ix starts at x[ix*d].
/// @param wishart Wishart prior hyperparameters (uses wishart.gamma and wishart.m).
/// @param err     Output: scalar objective value.
/// @param J       Output: gradient of length k + d*k + icf_sz*k, laid out as
///                [d/d alphas | d/d means | d/d icf]. Zeroed before accumulation.
void gmm_objective_d(int d, int k, int n,
    const double *alphas,
    const double *means,
    const double *icf,
    const double *x,
    Wishart wishart,
    double *err,
    double *J)
{
    // Evaluated in floating point: the former `-n * d` was an int product
    // that could overflow for large n*d.
    const double CONSTANT = -0.5 * d * n * log(2 * PI);
    const int icf_sz = d * (d + 1) / 2;
    const int L_sz = icf_sz - d;  // strictly lower-triangular entries per component

    vector<double> Qdiags(d*k);
    vector<double> sum_qs(k);
    vector<double> main_term(k);
    vector<double> xcentered(d);
    vector<double> Qxcentered(d);

    // NOTE(review): presumably fills Qdiags with the (exponentiated) diagonal of
    // each component's Q and sum_qs with the per-component sum of the diagonal
    // parameters -- confirm against preprocess_qs.
    preprocess_qs(d, k, icf, sum_qs.data(), Qdiags.data());

    std::fill(J, J + (k + d * k + icf_sz * k), 0.0);

    // Per-point, per-component partial derivatives; they are scaled by the
    // log-sum-exp weights in main_term before being accumulated into J.
    vector<double> curr_means_d(d*k);
    vector<double> curr_q_d(d*k);
    vector<double> curr_L_d(L_sz * k);

    double *alphas_d = J;
    double *means_d = J + k;
    double *icf_d = J + k + d * k;

    double slse = 0.;
    for (int ix = 0; ix < n; ix++)
    {
        const double* const curr_x = &x[ix*d];
        for (int ik = 0; ik < k; ik++)
        {
            const int icf_off = ik * icf_sz;
            double *Qdiag = Qdiags.data() + ik * d;
            // Lower-triangular block of this component: it follows the d diagonal
            // parameters. Both Q and Q^T below must be built from this same block;
            // previously Qtransposetimesx was given &icf[icf_off], which mixed the
            // diagonal parameters into L and corrupted the means gradient.
            const double *Lparams = icf + icf_off + d;

            subtract(d, curr_x, &means[ik*d], xcentered.data());
            Qtimesx(d, Qdiag, Lparams, xcentered.data(), Qxcentered.data());
            Qtransposetimesx(d, Qdiag, Lparams, Qxcentered.data(), curr_means_d.data() + ik * d);
            compute_q_inner_term(d, Qdiag, xcentered.data(), Qxcentered.data(), curr_q_d.data() + ik * d);
            // .data() + offset instead of &v[i]: when d == 1, L_sz == 0 and
            // curr_L_d is empty, so indexing it with operator[] would be UB.
            compute_L_inner_term(d, xcentered.data(), Qxcentered.data(), curr_L_d.data() + ik * L_sz);
            main_term[ik] = alphas[ik] + sum_qs[ik] - 0.5*sqnorm(d, Qxcentered.data());
        }
        // NOTE(review): logsumexp_d is called in-place; it presumably returns the
        // log-sum-exp and overwrites main_term with its gradient (the per-component
        // weights used below) -- confirm against logsumexp_d.
        slse += logsumexp_d(k, main_term.data(), main_term.data());
        for (int ik = 0; ik < k; ik++)
        {
            const int means_off = ik * d;
            const int icf_off = ik * icf_sz;
            const double weight = main_term[ik];
            alphas_d[ik] += weight;
            for (int id = 0; id < d; id++)
            {
                means_d[means_off + id] += curr_means_d[means_off + id] * weight;
                icf_d[icf_off + id] += curr_q_d[ik*d + id] * weight;
            }
            for (int i = d; i < icf_sz; i++)
            {
                icf_d[icf_off + i] += curr_L_d[ik*L_sz + (i - d)] * weight;
            }
        }
    }

    // Contribution of the -n * logsumexp(alphas) term.
    vector<double> lse_alphas_d(k);
    double lse_alphas = logsumexp_d(k, alphas, lse_alphas_d.data());
    for (int ik = 0; ik < k; ik++)
    {
        alphas_d[ik] -= n * lse_alphas_d[ik];
        // Wishart-prior gradient. NOTE(review): these match a prior of the form
        // 0.5*gamma^2*(||Qdiag||^2 + ||L||^2) - m*sum(q) with Qdiag = exp(q);
        // confirm against log_wishart_prior.
        for (int id = 0; id < d; id++)
        {
            icf_d[ik*icf_sz + id] += wishart.gamma*wishart.gamma * Qdiags[ik*d + id] * Qdiags[ik*d + id]
                - wishart.m;
        }
        for (int i = d; i < icf_sz; i++)
        {
            icf_d[ik*icf_sz + i] += wishart.gamma*wishart.gamma*icf[ik*icf_sz + i];
        }
    }

    *err = CONSTANT + slse - n * lse_alphas;
    *err += log_wishart_prior(d, k, wishart, sum_qs.data(), Qdiags.data(), icf);
}