in src/evaluator.py
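# NOTE (assumption): this excerpt depends on names defined elsewhere in
# src/evaluator.py -- os, time, torch, logger, to_cuda, check_hypothesis,
# idx_to_infix and concurrent.futures.ProcessPoolExecutor are expected to be
# imported or defined at module level.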
def enc_dec_step_beam_fast(self, data_type, task, scores, size=None):
"""
Encoding / decoding step with beam generation and SymPy check.
"""
params = self.params
env = self.env
encoder = (
self.modules["encoder"].module
if params.multi_gpu
else self.modules["encoder"]
)
decoder = (
self.modules["decoder"].module
if params.multi_gpu
else self.modules["decoder"]
)
encoder.eval()
decoder.eval()
assert params.eval_verbose in [0, 1, 2]
assert params.eval_verbose_print is False or params.eval_verbose > 0
assert task in [
"ode_convergence_speed",
"ode_control",
"fourier_cond_init",
]
# stats
xe_loss = 0
n_valid = torch.zeros(1000, dtype=torch.long)
n_total = torch.zeros(1000, dtype=torch.long)
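    # counters are indexed by the number of operators of each equation;
    # 1000 is simply a safe upper bound on that count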
# iterator
iterator = env.create_test_iterator(
data_type,
task,
data_path=self.trainer.data_path,
batch_size=params.batch_size_eval,
params=params,
size=params.eval_size,
)
eval_size = len(iterator.dataset)
# save beam results
beam_log = {}
hyps_to_eval = []
for (x1, len1), (x2, len2), nb_ops in iterator:
# update logs
for i in range(len(len1)):
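            # slicing [1 : len - 1] drops the first and last tokens, which are
            # assumed to be the sequence delimiters (BOS / EOS)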
beam_log[i + n_total.sum().item()] = {
"src": x1[1 : len1[i] - 1, i].tolist(),
"tgt": x2[1 : len2[i] - 1, i].tolist(),
"nb_ops": nb_ops[i].item(),
"hyps": [],
}
# target words to predict
alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
pred_mask = (
alen[:, None] < len2[None] - 1
) # do not predict anything given the last target word
y = x2[1:].masked_select(pred_mask[:-1])
assert len(y) == (len2 - 1).sum().item()
# optionally truncate input
x1_, len1_ = x1, len1
# cuda
x1_, len1_, x2, len2, y = to_cuda(x1_, len1_, x2, len2, y)
bs = len(len1)
# forward
encoded = encoder("fwd", x=x1_, lengths=len1_, causal=False)
decoded = decoder(
"fwd",
x=x2,
lengths=len2,
causal=True,
src_enc=encoded.transpose(0, 1),
src_len=len1_,
)
word_scores, loss = decoder(
"predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=True
)
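        # word_scores holds one row of vocabulary logits per masked target
        # position; loss is the cross-entropy over those positions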
# correct outputs per sequence / valid top-1 predictions
t = torch.zeros_like(pred_mask, device=y.device)
        t[pred_mask] = word_scores.max(1)[1] == y
valid = (t.sum(0) == len2 - 1).cpu().long()
# update stats
xe_loss += loss.item() * len(y)
n_valid.index_add_(-1, nb_ops, valid)
n_total.index_add_(-1, nb_ops, torch.ones_like(nb_ops))
# update equations that were solved greedily
for i in range(len(len1)):
if valid[i]:
beam_log[i + n_total.sum().item() - bs]["hyps"].append(
(None, None, True)
)
        # skip the rest of the batch if every greedy prediction is correct;
        # with eval_verbose == 2, still run a full beam search even on
        # correct greedy generations
        if valid.sum() == len(valid) and params.eval_verbose < 2:
            continue
# invalid top-1 predictions - check if there is a solution in the beam
invalid_idx = (1 - valid).nonzero().view(-1)
logger.info(
f"({n_total.sum().item()}/{eval_size}) Found "
f"{bs - len(invalid_idx)}/{bs} valid top-1 predictions. "
"Generating solutions ..."
)
# generate with beam search
_, _, generations = decoder.generate_beam(
encoded.transpose(0, 1),
len1_,
beam_size=params.beam_size,
length_penalty=params.beam_length_penalty,
early_stopping=params.beam_early_stopping,
max_len=params.max_len,
)
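        # generations[i].hyp holds the (score, token_ids) pairs produced by
        # the beam for the i-th input sequence of the batch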
# prepare inputs / hypotheses to check
        # if eval_verbose < 2, do not check beam hypotheses for equations
        # already solved greedily
for i in range(len(generations)):
if valid[i] and params.eval_verbose < 2:
continue
for j, (score, hyp) in enumerate(
sorted(generations[i].hyp, key=lambda x: x[0], reverse=True)
):
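                # "i" below is the absolute equation index, used later to
                # merge the checked hypotheses back into beam_log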
hyps_to_eval.append(
{
"i": i + n_total.sum().item() - bs,
"j": j,
"score": score,
"src": x1[1 : len1[i] - 1, i].tolist(),
"tgt": x2[1 : len2[i] - 1, i].tolist(),
"hyp": hyp[1:].tolist(),
"task": task,
}
)
# if the Jacobian is also predicted, only look at the eigenvalue
if task == "ode_convergence_speed":
sep_id = env.word2id[env.mtrx_separator]
for x in hyps_to_eval:
x["tgt"] = (
x["tgt"][x["tgt"].index(sep_id) + 1 :]
if sep_id in x["tgt"]
else x["tgt"]
)
x["hyp"] = (
x["hyp"][x["hyp"].index(sep_id) + 1 :]
if sep_id in x["hyp"]
else x["hyp"]
)
# solutions that perfectly match the reference with greedy decoding
    assert all(
        len(v["hyps"]) == 0
        or (len(v["hyps"]) == 1 and v["hyps"][0] == (None, None, True))
        for v in beam_log.values()
    )
init_valid = sum(
int(len(v["hyps"]) == 1 and v["hyps"][0][2] is True)
for v in beam_log.values()
)
logger.info(
f"Found {init_valid} solutions with greedy decoding "
"(perfect reference match)."
)
# check hypotheses with multiprocessing
eval_hyps = []
start = time.time()
logger.info(
f"Checking {len(hyps_to_eval)} hypotheses for "
f"{len(set(h['i'] for h in hyps_to_eval))} equations ..."
)
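    # check_hypothesis must be picklable (e.g. a module-level function) since
    # it is dispatched to worker processes; each worker runs the SymPy check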
with ProcessPoolExecutor(max_workers=20) as executor:
for output in executor.map(check_hypothesis, hyps_to_eval, chunksize=1):
eval_hyps.append(output)
logger.info(f"Evaluation done in {time.time() - start:.2f} seconds.")
# update beam logs
for hyp in eval_hyps:
beam_log[hyp["i"]]["hyps"].append(
(hyp["hyp"], hyp["score"], hyp["is_valid"])
)
# print beam results
beam_valid = sum(
int(any(h[2] for h in v["hyps"]) and v["hyps"][0][1] is not None)
for v in beam_log.values()
)
all_valid = sum(int(any(h[2] for h in v["hyps"])) for v in beam_log.values())
assert init_valid + beam_valid == all_valid
assert len(beam_log) == n_total.sum().item()
    logger.info(
        f"Found {all_valid} valid solutions: {init_valid} with greedy decoding "
        f"(perfect reference match), {beam_valid} with beam search."
    )
# update valid equation statistics
n_valid = torch.zeros(1000, dtype=torch.long)
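    # recount per-difficulty successes, now also including beam-search solutions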
for i, v in beam_log.items():
if any(h[2] for h in v["hyps"]):
n_valid[v["nb_ops"]] += 1
assert n_valid.sum().item() == all_valid
# export evaluation details
if params.eval_verbose:
eval_path = os.path.join(
params.dump_path, f"eval.beam.{data_type}.{task}.{scores['epoch']}"
)
with open(eval_path, "w") as f:
# for each equation
for i, res in sorted(beam_log.items()):
n_eq_valid = sum([int(v) for _, _, v in res["hyps"]])
src = idx_to_infix(env, res["src"], input=True).replace("|", " | ")
tgt = " ".join(env.id2word[wid] for wid in res["tgt"])
s = (
f"Equation {i} ({n_eq_valid}/{len(res['hyps'])})\n"
f"src={src}\ntgt={tgt}\n"
)
for hyp, score, valid in res["hyps"]:
if score is None:
assert hyp is None
s += f"{int(valid)} GREEDY\n"
else:
try:
hyp = " ".join(hyp)
except Exception:
hyp = f"INVALID OUTPUT {hyp}"
s += f"{int(valid)} {score :.3e} {hyp}\n"
if params.eval_verbose_print:
logger.info(s)
f.write(s + "\n")
f.flush()
logger.info(f"Evaluation results written in {eval_path}")
# log
_n_valid = n_valid.sum().item()
_n_total = n_total.sum().item()
    logger.info(
        f"{_n_valid}/{_n_total} ({100.0 * _n_valid / _n_total:.2f}%) equations "
        "were evaluated correctly."
    )
# compute perplexity and prediction accuracy
assert _n_total == eval_size
    scores[f"{data_type}_{task}_xe_loss"] = xe_loss / _n_total
scores[f"{data_type}_{task}_beam_acc"] = 100.0 * _n_valid / _n_total
# per class perplexity and prediction accuracy
for i in range(len(n_total)):
if n_total[i].item() == 0:
continue
        logger.info(
            f"{i}: {n_valid[i].sum().item()} / {n_total[i].item()} "
            f"({100.0 * n_valid[i].sum().item() / max(n_total[i].item(), 1):.2f}%)"
        )
scores[f"{data_type}_{task}_beam_acc_{i}"] = (
100.0 * n_valid[i].sum().item() / max(n_total[i].item(), 1)
)