# syne_tune/optimizer/schedulers/searchers/bayesopt/gpautograd/learncurve/freeze_thaw.py
from typing import Dict

import autograd.numpy as anp

# The helpers `_colvec`, `_rowvec`, `_flatvec`, `_inner_product`, as well as
# `cholesky_factorization` and `ExponentialDecayBaseKernelFunction`, are
# defined or imported elsewhere in this module (exact import paths omitted)


def resource_kernel_likelihood_computations(
        precomputed: Dict, res_kernel: ExponentialDecayBaseKernelFunction,
        noise_variance, skip_c_d: bool = False) -> Dict:
"""
Given `precomputed` from `resource_kernel_likelihood_precomputations` and
resource kernel function `res_kernel`, compute quantities required for
inference and marginal likelihood computation, pertaining to the likelihood
of a additive model, as in the Freeze-Thaw paper.
Note that `res_kernel` takes raw (unnormalized) r as inputs. The code here
works for any resource kernel and mean function, not just for
:class:`ExponentialDecayBaseKernelFunction`.
Results returned are:
- c: n vector [c_i]
- d: n vector [d_i], positive
- vtv: n vector [|v_i|^2]
- wtv: (n, F) matrix[(W_i)^T v_i], F number of fantasy samples
- wtw: n vector [|w_i|^2] (only if no fantasizing)
- lfact_all: Cholesky factor for kernel matrix
- ydims: Target vector sizes (copy from `precomputed`)
:param precomputed: Output of `resource_kernel_likelihood_precomputations`
:param res_kernel: Kernel k(r, r') over resources
:param noise_variance: Noise variance sigma^2
:param skip_c_d: If True, c and d are not computed
:return: Quantities required for inference and learning criterion
"""
num_configs = precomputed['num_configs']
num_all_configs = num_configs[0]
r_min, r_max = res_kernel.r_min, res_kernel.r_max
num_res = r_max + 1 - r_min
assert num_all_configs > 0, "targets must not be empty"
assert num_res > 0, f"r_min = {r_min} must be <= r_max = {r_max}"
num_fantasy_samples = precomputed['yflat'].shape[1]
compute_wtw = num_fantasy_samples == 1
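    # The |w_i|^2 terms are needed by the learning criterion only when there
    # is no fantasizing (F == 1)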
    # Compute Cholesky factor of A = K(r, r) / sigma^2 + I over the full
    # resource range [r_min, r_max]; per-config subproblems use its leading
    # submatrices
ydims = precomputed['ydims']
rvals = _colvec(anp.arange(r_min, r_min + num_res))
means_all = _flatvec(res_kernel.mean_function(rvals))
amat = res_kernel(rvals, rvals) / noise_variance + anp.diag(
anp.ones(num_res))
# TODO: Do we need AddJitterOp here?
lfact_all = cholesky_factorization(amat) # L (Cholesky factor)
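    # With A = K(r, r) / sigma^2 + I and A = L L^T, the loop below amounts to
    # a joint forward substitution: for each config i with ydim_i targets y_i,
    #   w_i = L[:ydim_i, :ydim_i]^{-1} (y_i - m),   v = L^{-1} 1,
    # from which |v_i|^2, (W_i)^T v_i and |w_i|^2 are obtained directly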
# Loop over ydim
yflat = precomputed['yflat']
off = num_all_configs
ilscal = 1.0 / lfact_all[0, 0]
vvec = anp.array([ilscal]).reshape((1, 1))
    # `yflat` is a (*, F) matrix, where F == `num_fantasy_samples`. Each
    # per-level (num, F) block of it is flattened into a single row of
    # `wmat`, and reshaped back to (*, F) before being written into `wtv_lst`
wmat = _rowvec(yflat[:off, :] - means_all[0]) * ilscal
    # Note: We need the detour via `wtv_lst`, `wtw_lst` because `autograd`
    # does not support overwriting the contents of an `ndarray`. Their role
    # is to collect parts of the final vectors, in reverse order
wtv_lst = []
wtw_lst = []
num_prev = off
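    # Loop invariant: at the top of iteration `ydim`, `wmat` holds rows
    # 0, ..., ydim - 1 of L^{-1} (Y - m) for all configs still active, and
    # `vvec` holds L[:ydim, :ydim]^{-1} 1 as a row vector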
for ydim, num in enumerate(num_configs[1:], start=1):
if num < num_prev:
# These parts are done:
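            # Configs beyond position `num` have exactly `ydim` targets, so
            # their columns in `wmat` are complete and can be finalized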
pos = num * num_fantasy_samples
wdone = wmat[:, pos:]
wtv_part = anp.reshape(
anp.matmul(vvec, wdone), (-1, num_fantasy_samples))
wtv_lst.append(wtv_part)
if compute_wtw:
wtw_lst.append(_flatvec(anp.sum(anp.square(wdone), axis=0)))
wmat = wmat[:, :pos]
num_prev = num
# Update W matrix
rhs = _rowvec(yflat[off:(off + num), :] - means_all[ydim])
off += num
lvec = _rowvec(lfact_all[ydim, :ydim])
ilscal = 1.0 / lfact_all[ydim, ydim]
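        # Row `ydim` of the forward substitution, one entry per still-active
        # config and fantasy sample: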
w_new = (rhs - anp.matmul(lvec, wmat)) * ilscal
wmat = anp.concatenate((wmat, w_new), axis=0)
# Update v vector (row vector)
v_new = anp.array(
[(1.0 - _inner_product(lvec, vvec)) * ilscal]).reshape((1, 1))
vvec = anp.concatenate((vvec, v_new), axis=1)
wtv_part = anp.reshape(
anp.matmul(vvec, wmat), (-1, num_fantasy_samples))
wtv_lst.append(wtv_part)
wtv_all = anp.concatenate(tuple(reversed(wtv_lst)), axis=0)
if compute_wtw:
wtw_lst.append(_flatvec(anp.sum(anp.square(wmat), axis=0)))
wtw_all = anp.concatenate(tuple(reversed(wtw_lst)), axis=0)
vtv_for_ydim = anp.cumsum(anp.square(vvec))
vtv_all = anp.array([vtv_for_ydim[ydim - 1] for ydim in ydims])
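    # vtv_all[i] = |v_i|^2 = 1^T A[:d, :d]^{-1} 1 with d = ydims[i], since
    # v = L^{-1} 1 and forward substitution only uses earlier entries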
# Compile results
result = {
'num_data': sum(ydims),
'vtv': vtv_all,
'wtv': wtv_all,
'lfact_all': lfact_all,
'means_all': means_all,
'ydims': ydims}
if compute_wtw:
result['wtw'] = wtw_all
if not skip_c_d:
result['c'] = anp.zeros(num_all_configs)
result['d'] = anp.zeros(num_all_configs)
return result
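

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the `precomputed` layout inferred from the function above: configs
# sorted by decreasing number of targets, with `yflat` stacking all targets
# at r_min first, then those at r_min + 1, and so on. `_StubResourceKernel`
# is a hypothetical stand-in for `ExponentialDecayBaseKernelFunction`,
# exposing only the interface used here: `__call__(r1, r2)`,
# `mean_function(r)`, `r_min`, `r_max`.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class _StubResourceKernel:
        r_min, r_max = 1, 3

        def __call__(self, r1, r2):
            # Squared-exponential kernel over resource levels (illustration)
            return anp.exp(-0.5 * anp.square(r1 - anp.reshape(r2, (1, -1))))

        def mean_function(self, r):
            return anp.zeros_like(r)

    # Three configs with 3, 2, 1 observed targets (F = 1 fantasy sample):
    precomputed = {
        'num_configs': [3, 2, 1],  # num_configs[j]: #configs with > j targets
        'ydims': [3, 2, 1],        # number of targets per config
        'yflat': anp.array(
            [[0.5], [0.4], [0.3],  # targets at r = 1 for configs 0, 1, 2
             [0.6], [0.5],         # targets at r = 2 for configs 0, 1
             [0.7]])}              # target at r = 3 for config 0
    result = resource_kernel_likelihood_computations(
        precomputed, _StubResourceKernel(), noise_variance=0.1)
    # result['wtv'] has shape (3, 1); 'vtv' and 'wtw' have shape (3,)
    print({k: getattr(v, 'shape', v) for k, v in result.items()})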