in syne_tune/optimizer/schedulers/searchers/bayesopt/gpautograd/learncurve/issm.py [0:0]
def issm_likelihood_computations(
        precomputed: Dict, issm_params: Dict, r_min: int, r_max: int,
        skip_c_d: bool = False,
        profiler: Optional[SimpleProfiler] = None) -> Dict:
    """
    Given `precomputed` from `issm_likelihood_precomputations` and ISSM
    parameters `issm_params`, compute quantities required for inference and
    marginal likelihood computation, pertaining to the ISSM likelihood.

    The index for r is range(r_min, r_max + 1). Observations must be contiguous
    from r_min. The ISSM parameters are:
    - alpha: n-vector, negative
    - beta: n-vector
    - gamma: scalar, positive

    Results returned are:
    - num_data: total number of observations, sum(precomputed['ydims'])
    - c: n vector [c_i], negative (0 for a fully observed config); only if
      not `skip_c_d`
    - d: n vector [d_i], positive (0 for a fully observed config); only if
      not `skip_c_d`
    - vtv: n vector [|v_i|^2]
    - wtv: (n, F) matrix [(W_i)^T v_i], F number of fantasy samples
    - wtw: n-vector [|w_i|^2] (only if no fantasizing, i.e. F == 1)

    :param precomputed: Output of `issm_likelihood_precomputations`. Entries
        read here: 'num_configs', 'ydims', 'deltay', 'logr'. The entries of
        'num_configs' must be non-increasing (this is asserted)
    :param issm_params: Parameters of ISSM likelihood (keys 'alpha', 'beta',
        'gamma')
    :param r_min: Smallest resource value
    :param r_max: Largest resource value
    :param skip_c_d: If True, c and d are not computed
    :param profiler: Optional profiler. If given, the computation of c, d is
        timed as 'issm_part1', the computation of vtv, wtv, wtw as
        'issm_part2'
    :return: Quantities required for inference and learning criterion
    """
    num_all_configs = precomputed['num_configs'][0]
    num_res = r_max + 1 - r_min
    assert num_all_configs > 0, "targets must not be empty"
    assert num_res > 0, f"r_min = {r_min} must be <= r_max = {r_max}"
    num_fantasy_samples = precomputed['deltay'].shape[1]
    compute_wtw = num_fantasy_samples == 1
    alphas = _flatvec(issm_params['alpha'])
    betas = _flatvec(issm_params['beta'])
    gamma = issm_params['gamma']
    n = getval(alphas.size)
    assert n == num_all_configs, f"alpha.size = {n} != {num_all_configs}"
    n = getval(betas.size)
    assert n == num_all_configs, f"beta.size = {n} != {num_all_configs}"
    if not skip_c_d:
        # We could probably refactor this to fit into the loop below, but it
        # seems subdominant
        if profiler is not None:
            profiler.start('issm_part1')
        # log(r) for r in range(r_min + 1, r_max + 1), computed once and
        # shared by all configs: entry k corresponds to r = r_min + 1 + k.
        # log(r_min) is never needed, because every config has ydim >= 1
        # observations (asserted below)
        log_res = np.log(np.arange(r_min + 1, r_max + 1))
        c_lst = []
        d_lst = []
        for i, ydim in enumerate(precomputed['ydims']):
            alpha = alphas[i]
            alpha_m1 = alpha - 1.0
            beta = betas[i]
            # Config i is observed in range(r_min, r_min + ydim)
            assert 0 < ydim <= num_res,\
                f"len(y[{i}]) = {ydim}, num_res = {num_res}"
            # c_i, d_i
            if ydim < num_res:
                # log(r) * (alpha - 1) + beta for the unobserved
                # r in range(r_min + ydim, r_max + 1)
                lrvec = log_res[(ydim - 1):] * alpha_m1 + beta
                c_scal = alpha * anp.exp(logsumexp(lrvec))
                d_scal = anp.square(gamma * alpha) * anp.exp(
                    logsumexp(lrvec * 2.0))
                c_lst.append(c_scal)
                d_lst.append(d_scal)
            else:
                # Fully observed up to r_max: no remaining terms
                c_lst.append(0.0)
                d_lst.append(0.0)
        if profiler is not None:
            profiler.stop('issm_part1')
    # Loop over ydim
    if profiler is not None:
        profiler.start('issm_part2')
    deltay = precomputed['deltay']
    logr = precomputed['logr']
    # The first `num_all_configs` rows of `deltay` hold [y_0] for all configs.
    # Subsequent chunks of `num` rows are consumed one per loop iteration
    off_dely = num_all_configs
    vvec = anp.ones(off_dely)
    wmat = deltay[:off_dely, :]  # [y_0]
    vtv = anp.ones(off_dely)
    wtv = wmat.copy()
    if compute_wtw:
        wtw = _flatvec(anp.square(wmat))
    # Note: We need the detour via `vtv_lst`, etc, because `autograd` does not
    # support overwriting the content of an `ndarray`. Their role is to collect
    # parts of the final vectors, in reverse ordering
    vtv_lst = []
    wtv_lst = []
    wtw_lst = []
    alpham1s = alphas - 1
    num_prev = off_dely
    for num in precomputed['num_configs'][1:]:
        # The working vectors can only shrink. An increasing entry would
        # otherwise cause a confusing broadcast error (or, if num_prev == 1,
        # silent broadcasting to wrong shapes)
        assert num <= num_prev, \
            f"precomputed['num_configs'] must be non-increasing, but " \
            f"{num} > {num_prev}"
        if num < num_prev:
            # Size of working vectors is shrinking
            assert vtv.size == num_prev
            # These parts are done: Collect them in the lists
            # All vectors are resized to `num`, dropping the tails
            vtv_lst.append(vtv[num:])
            wtv_lst.append(wtv[num:, :])
            vtv = vtv[:num]
            wtv = wtv[:num, :]
            if compute_wtw:
                wtw_lst.append(wtw[num:])
                wtw = wtw[:num]
            alphas = alphas[:num]
            alpham1s = alpham1s[:num]
            betas = betas[:num]
            vvec = vvec[:num]
            wmat = wmat[:num, :]
            num_prev = num
        # [a_{j-1}]
        # `logr` is aligned with `deltay` without its leading [y_0] block,
        # hence the offset shift by `num_all_configs`
        off_logr = off_dely - num_all_configs
        logr_curr = logr[off_logr:(off_logr + num)]
        avec = alphas * anp.exp(logr_curr * alpham1s + betas)
        evec = avec * gamma + 1  # [e_j]
        vvec = vvec * evec  # [v_j]
        deltay_curr = deltay[off_dely:(off_dely + num), :]
        off_dely += num
        wmat = _colvec(evec) * wmat + deltay_curr + _colvec(avec)  # [w_j]
        vtv = vtv + anp.square(vvec)
        if compute_wtw:
            wtw = wtw + _flatvec(anp.square(wmat))
        wtv = wtv + _colvec(vvec) * wmat
    # Remaining (longest-observed) parts, then restore the original ordering
    vtv_lst.append(vtv)
    wtv_lst.append(wtv)
    vtv_all = anp.concatenate(tuple(reversed(vtv_lst)), axis=0)
    wtv_all = anp.concatenate(tuple(reversed(wtv_lst)), axis=0)
    if compute_wtw:
        wtw_lst.append(wtw)
        wtw_all = anp.concatenate(tuple(reversed(wtw_lst)), axis=0)
    if profiler is not None:
        profiler.stop('issm_part2')
    # Compile results
    result = {
        'num_data': sum(precomputed['ydims']),
        'vtv': vtv_all,
        'wtv': wtv_all}
    if compute_wtw:
        result['wtw'] = wtw_all
    if not skip_c_d:
        result['c'] = anp.array(c_lst)
        result['d'] = anp.array(d_lst)
    return result