def resource_kernel_likelihood_computations()

in syne_tune/optimizer/schedulers/searchers/bayesopt/gpautograd/learncurve/freeze_thaw.py


def resource_kernel_likelihood_computations(
        precomputed: Dict, res_kernel: ExponentialDecayBaseKernelFunction,
        noise_variance, skip_c_d: bool = False) -> Dict:
    """
    Given `precomputed` from `resource_kernel_likelihood_precomputations` and
    resource kernel function `res_kernel`, compute quantities required for
    inference and marginal likelihood computation, pertaining to the likelihood
    of an additive model, as in the Freeze-Thaw paper.

    Note that `res_kernel` takes raw (unnormalized) r as inputs. The code here
    works for any resource kernel and mean function, not just for
    :class:`ExponentialDecayBaseKernelFunction`.

    Results returned are:
    - c: n vector [c_i]
    - d: n vector [d_i], positive
    - vtv: n vector [|v_i|^2]
    - wtv: (n, F) matrix [(W_i)^T v_i], F the number of fantasy samples
    - wtw: n vector [|w_i|^2] (only if no fantasizing)
    - lfact_all: Cholesky factor of K(R, R) / sigma^2 + I over the full resource grid
    - ydims: Target vector sizes (copy from `precomputed`)

    :param precomputed: Output of `resource_kernel_likelihood_precomputations`
    :param res_kernel: Kernel k(r, r') over resources
    :param noise_variance: Noise variance sigma^2
    :param skip_c_d: If True, c and d are not computed
    :return: Quantities required for inference and learning criterion

    """
    num_configs = precomputed['num_configs']
    num_all_configs = num_configs[0]
    r_min, r_max = res_kernel.r_min, res_kernel.r_max
    num_res = r_max + 1 - r_min
    assert num_all_configs > 0, "targets must not be empty"
    assert num_res > 0, f"r_min = {r_min} must be <= r_max = {r_max}"
    num_fantasy_samples = precomputed['yflat'].shape[1]
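    # `wtw` (the |w_i|^2 terms) is only needed when there is a single target
    # column per config, i.e. no fantasizing (see docstring)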
    compute_wtw = num_fantasy_samples == 1

    # Compute Cholesky factor over the full resource grid, which covers the
    # largest target vector size
    ydims = precomputed['ydims']
    rvals = _colvec(anp.arange(r_min, r_min + num_res))
    means_all = _flatvec(res_kernel.mean_function(rvals))
    amat = res_kernel(rvals, rvals) / noise_variance + anp.diag(
        anp.ones(num_res))
    # TODO: Do we need AddJitterOp here?
    lfact_all = cholesky_factorization(amat)  # L (Cholesky factor)
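    # `lfact_all` is the lower Cholesky factor of A = K(R, R) / noise_variance + I
    # over the full resource grid R = [r_min, ..., r_max]. Since every curve is
    # observed from r_min onwards, the matrix for a curve with d targets is the
    # leading d x d block of A, and its Cholesky factor is the leading d x d
    # block of `lfact_all`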

    # Loop over ydim
    yflat = precomputed['yflat']
    off = num_all_configs
    ilscal = 1.0 / lfact_all[0, 0]
    vvec = anp.array([ilscal]).reshape((1, 1))
    # `yflat` is a (*, F) matrix, where F == `num_fantasy_samples`. These
    # matrices are flattened out as rows of `wmat`, and reshaped back before
    # writing into `wtv_lst`
    wmat = _rowvec(yflat[:off, :] - means_all[0]) * ilscal
    # Note: We need the detour via `wtv_lst`, etc., because `autograd` does not
    # support overwriting the content of an `ndarray`. These lists collect
    # parts of the final vectors, in reverse order
    wtv_lst = []
    wtw_lst = []
    num_prev = off
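    # Loop invariant: after processing `ydim`, the columns of `wmat` (one per
    # config and fantasy sample) hold L[:ydim+1, :ydim+1]^{-1} (y_i - means),
    # for all configs with more than `ydim` observed targets, and `vvec` holds
    # the row vector (L[:ydim+1, :ydim+1]^{-1} 1)^T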
    for ydim, num in enumerate(num_configs[1:], start=1):
        if num < num_prev:
            # These parts are done:
            pos = num * num_fantasy_samples
            wdone = wmat[:, pos:]
            wtv_part = anp.reshape(
                anp.matmul(vvec, wdone), (-1, num_fantasy_samples))
            wtv_lst.append(wtv_part)
            if compute_wtw:
                wtw_lst.append(_flatvec(anp.sum(anp.square(wdone), axis=0)))
            wmat = wmat[:, :pos]
            num_prev = num
        # Update W matrix
        rhs = _rowvec(yflat[off:(off + num), :] - means_all[ydim])
        off += num
        lvec = _rowvec(lfact_all[ydim, :ydim])
        ilscal = 1.0 / lfact_all[ydim, ydim]
        w_new = (rhs - anp.matmul(lvec, wmat)) * ilscal
        wmat = anp.concatenate((wmat, w_new), axis=0)
        # Update v vector (row vector)
        v_new = anp.array(
            [(1.0 - _inner_product(lvec, vvec)) * ilscal]).reshape((1, 1))
        vvec = anp.concatenate((vvec, v_new), axis=1)
    wtv_part = anp.reshape(
        anp.matmul(vvec, wmat), (-1, num_fantasy_samples))
    wtv_lst.append(wtv_part)
    wtv_all = anp.concatenate(tuple(reversed(wtv_lst)), axis=0)
    if compute_wtw:
        wtw_lst.append(_flatvec(anp.sum(anp.square(wmat), axis=0)))
        wtw_all = anp.concatenate(tuple(reversed(wtw_lst)), axis=0)
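    # The cumulative sum of squared entries of `vvec` gives |v_{1:d}|^2
    # = 1^T A_{1:d}^{-1} 1 for every prefix length d; `vtv_all` picks the
    # entry matching each config's ydim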
    vtv_for_ydim = anp.cumsum(anp.square(vvec))
    vtv_all = anp.array([vtv_for_ydim[ydim - 1] for ydim in ydims])
    # Compile results
    result = {
        'num_data': sum(ydims),
        'vtv': vtv_all,
        'wtv': wtv_all,
        'lfact_all': lfact_all,
        'means_all': means_all,
        'ydims': ydims}
    if compute_wtw:
        result['wtw'] = wtw_all
    if not skip_c_d:
        result['c'] = anp.zeros(num_all_configs)
        result['d'] = anp.zeros(num_all_configs)
    return result
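
The loop above performs a column-wise forward substitution against prefixes of
the single Cholesky factor `lfact_all`. The standalone sketch below (plain
NumPy/SciPy with a made-up toy kernel and sizes, not the library's
`ExponentialDecayBaseKernelFunction`) illustrates the identity being exploited:
solving against the leading d x d block of the full factor gives the same
result as the incremental per-row updates used in the loop.

import numpy as np
from scipy.linalg import cholesky, solve_triangular

rng = np.random.default_rng(0)
num_res = 5
rvals = np.arange(1, num_res + 1, dtype=float)
# Toy positive definite matrix standing in for K(R, R) / sigma^2 + I
amat = np.exp(-0.3 * np.abs(rvals[:, None] - rvals[None, :])) + np.eye(num_res)
lfact = cholesky(amat, lower=True)  # plays the role of `lfact_all`

d = 3                                # a curve observed at the first 3 resources
y = rng.normal(size=d)               # toy centered targets (y_i - mu)
# Direct triangular solve against the leading d x d block of the factor ...
w_direct = solve_triangular(lfact[:d, :d], y, lower=True)
# ... matches the row-by-row forward substitution performed in the loop above
w_incr = np.empty(d)
for j in range(d):
    w_incr[j] = (y[j] - lfact[j, :j] @ w_incr[:j]) / lfact[j, j]
np.testing.assert_allclose(w_direct, w_incr)

The same prefix property is what allows the function to factorize the matrix
once for the full resource grid and reuse leading blocks of `lfact_all` for
curves of every length.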