def issm_likelihood_computations()

in syne_tune/optimizer/schedulers/searchers/bayesopt/gpautograd/learncurve/issm.py


def issm_likelihood_computations(
        precomputed: Dict, issm_params: Dict, r_min: int, r_max: int,
        skip_c_d: bool = False,
        profiler: Optional[SimpleProfiler] = None) -> Dict:
    """
    Given `precomputed` from `issm_likelihood_precomputations` and ISSM
    parameters `issm_params`, compute quantities required for inference and
    marginal likelihood computation, pertaining to the ISSM likelihood.

    The index for r is range(r_min, r_max + 1). Observations must be contiguous
    from r_min. The ISSM parameters are:
    - alpha: n-vector, negative
    - beta: n-vector
    - gamma: scalar, positive

    Results returned are:
    - c: n-vector [c_i], negative
    - d: n-vector [d_i], positive
    - vtv: n-vector [|v_i|^2]
    - wtv: (n, F) matrix [(W_i)^T v_i], where F is the number of fantasy samples
    - wtw: n-vector [|w_i|^2] (only if no fantasizing)

    :param precomputed: Output of `issm_likelihood_precomputations`
    :param issm_params: Parameters of ISSM likelihood
    :param r_min: Smallest resource value
    :param r_max: Largest resource value
    :param skip_c_d: If True, c and d are not computed
    :param profiler: Optional profiler, used to time the two main parts
    :return: Quantities required for inference and learning criterion
    """
    num_all_configs = precomputed['num_configs'][0]
    num_res = r_max + 1 - r_min
    assert num_all_configs > 0, "targets must not be empty"
    assert num_res > 0, f"r_min = {r_min} must be <= r_max = {r_max}"
    num_fantasy_samples = precomputed['deltay'].shape[1]
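    # |w_i|^2 is only needed when there is no fantasizing (single sample column)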
    compute_wtw = num_fantasy_samples == 1
    alphas = _flatvec(issm_params['alpha'])
    betas = _flatvec(issm_params['beta'])
    gamma = issm_params['gamma']
    n = getval(alphas.size)
    assert n == num_all_configs, f"alpha.size = {n} != {num_all_configs}"
    n = getval(betas.size)
    assert n == num_all_configs, f"beta.size = {n} != {num_all_configs}"

    if not skip_c_d:
        # We could probably refactor this to fit into the loop below, but it
        # seems subdominant
        if profiler is not None:
            profiler.start('issm_part1')
        c_lst = []
        d_lst = []
        for i, ydim in enumerate(precomputed['ydims']):
            alpha = alphas[i]
            alpha_m1 = alpha - 1.0
            beta = betas[i]
            r_obs = r_min + ydim  # Observed in range(r_min, r_obs)
            assert 0 < ydim <= num_res,\
                f"len(y[{i}]) = {ydim}, num_res = {num_res}"
            # c_i, d_i
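            # c_i = alpha * sum_{r=r_obs}^{r_max} exp((alpha - 1) * log(r) + beta),
            # d_i = (gamma * alpha)^2 * sum over the same range of
            # exp(2 * ((alpha - 1) * log(r) + beta)); both sums are evaluated
            # via logsumexp for numerical stability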
            if ydim < num_res:
                lrvec = anp.array(
                    [np.log(r) for r in range(r_obs, r_max + 1)]) *\
                        alpha_m1 + beta
                c_scal = alpha * anp.exp(logsumexp(lrvec))
                d_scal = anp.square(gamma * alpha) * anp.exp(
                    logsumexp(lrvec * 2.0))
                c_lst.append(c_scal)
                d_lst.append(d_scal)
            else:
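                # All resources observed for this config: the sums above are
                # empty, so c_i = d_i = 0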
                c_lst.append(0.0)
                d_lst.append(0.0)
        if profiler is not None:
            profiler.stop('issm_part1')

    # Loop over ydim
    if profiler is not None:
        profiler.start('issm_part2')
    deltay = precomputed['deltay']
    logr = precomputed['logr']
    off_dely = num_all_configs
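    # Initialization at the first resource level: v_0 = 1 and w_0 = y_0 for
    # every config; vtv, wtv (and wtw) hold the running sums |v_i|^2,
    # (W_i)^T v_i and |w_i|^2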
    vvec = anp.ones(off_dely)
    wmat = deltay[:off_dely, :]  # [y_0]
    vtv = anp.ones(off_dely)
    wtv = wmat.copy()
    if compute_wtw:
        wtw = _flatvec(anp.square(wmat))
    # Note: We need the detour via `vtv_lst`, etc, because `autograd` does not
    # support overwriting the content of an `ndarray`. Their role is to collect
    # parts of the final vectors, in reverse ordering
    vtv_lst = []
    wtv_lst = []
    wtw_lst = []
    alpham1s = alphas - 1
    num_prev = off_dely
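    # Entry j of `num_configs` is the number of configs with at least j + 1
    # observations (a non-increasing sequence). Whenever it drops, the configs
    # that are finished are moved from the working vectors into the lists above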
    for num in precomputed['num_configs'][1:]:
        if num < num_prev:
            # Size of working vectors is shrinking
            assert vtv.size == num_prev
            # These parts are done: Collect them in the lists
            # All vectors are resized to `num`, dropping the tails
            vtv_lst.append(vtv[num:])
            wtv_lst.append(wtv[num:, :])
            vtv = vtv[:num]
            wtv = wtv[:num, :]
            if compute_wtw:
                wtw_lst.append(wtw[num:])
                wtw = wtw[:num]
            alphas = alphas[:num]
            alpham1s = alpham1s[:num]
            betas = betas[:num]
            vvec = vvec[:num]
            wmat = wmat[:num, :]
            num_prev = num
        # [a_{j-1}]
        off_logr = off_dely - num_all_configs
        logr_curr = logr[off_logr:(off_logr + num)]
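        # a_{j-1} = alpha * exp((alpha - 1) * log(r) + beta) for the current
        # resource level r; then e_j = gamma * a_{j-1} + 1, v_j = e_j * v_{j-1},
        # and w_j = e_j * w_{j-1} + deltay_j + a_{j-1}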
        avec = alphas * anp.exp(logr_curr * alpham1s + betas)
        evec = avec * gamma + 1  # [e_j]
        vvec = vvec * evec  # [v_j]
        deltay_curr = deltay[off_dely:(off_dely + num), :]
        off_dely += num
        wmat = _colvec(evec) * wmat + deltay_curr + _colvec(avec)  # [w_j]
        vtv = vtv + anp.square(vvec)
        if compute_wtw:
            wtw = wtw + _flatvec(anp.square(wmat))
        wtv = wtv + _colvec(vvec) * wmat
    vtv_lst.append(vtv)
    wtv_lst.append(wtv)
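    # Concatenating the collected pieces in reverse order restores the
    # original config ordering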
    vtv_all = anp.concatenate(tuple(reversed(vtv_lst)), axis=0)
    wtv_all = anp.concatenate(tuple(reversed(wtv_lst)), axis=0)
    if compute_wtw:
        wtw_lst.append(wtw)
        wtw_all = anp.concatenate(tuple(reversed(wtw_lst)), axis=0)
    if profiler is not None:
        profiler.stop('issm_part2')

    # Compile results
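    # 'num_data' is the total number of observed target values over all configs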
    result = {
        'num_data': sum(precomputed['ydims']),
        'vtv': vtv_all,
        'wtv': wtv_all}
    if compute_wtw:
        result['wtw'] = wtw_all
    if not skip_c_d:
        result['c'] = anp.array(c_lst)
        result['d'] = anp.array(d_lst)
    return result
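
For orientation, a minimal driver sketch follows (not taken from issm.py). It assumes that `issm_likelihood_precomputations` lives in the same module, accepts the per-config target vectors together with `r_min`, and returns the dict consumed above; the exact signature, the target layout, and the parameter values are assumptions for illustration only. The function body above also relies on module-level helpers (`anp` for `autograd.numpy`, `_flatvec`, `_colvec`, `getval`, `logsumexp`), which are presumed to be imported or defined elsewhere in issm.py.

# Hypothetical usage sketch; shapes follow the docstring above
# (alpha, beta are n-vectors, gamma is a positive scalar)
import autograd.numpy as anp

from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.learncurve.issm import (
    issm_likelihood_computations, issm_likelihood_precomputations)  # precomputations assumed to live here

r_min, r_max = 1, 5
targets = [
    anp.array([0.9, 0.7, 0.6]),  # config 0: observed at r = 1, 2, 3
    anp.array([0.8, 0.75]),      # config 1: observed at r = 1, 2
]
# Assumed signature and target layout for the companion precomputation routine
precomputed = issm_likelihood_precomputations(targets, r_min)
issm_params = {
    'alpha': anp.array([-0.5, -0.5]),  # negative n-vector
    'beta': anp.array([0.0, 0.0]),     # n-vector
    'gamma': 0.1,                      # positive scalar
}
result = issm_likelihood_computations(precomputed, issm_params, r_min, r_max)
# result['vtv'], result['wtv'] (plus 'wtw', 'c', 'd') feed posterior inference
# and the marginal likelihood criterion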