in src/evaluator.py [0:0]
def enc_dec_step(self, data_type, task, scores):
    """
    Encoding / decoding step.
    Evaluate the encoder / decoder pair on one test task and store the
    cross-entropy loss and (overall and per-nb_ops) accuracy in `scores`.
    """
    params = self.params
    env = self.env
    encoder = (
        self.modules["encoder"].module
        if params.multi_gpu
        else self.modules["encoder"]
    )
    decoder = (
        self.modules["decoder"].module
        if params.multi_gpu
        else self.modules["decoder"]
    )
    encoder.eval()
    decoder.eval()
    assert params.eval_verbose in [0, 1]
    assert params.eval_verbose_print is False or params.eval_verbose > 0
    assert task in [
        "ode_convergence_speed",
        "ode_control",
        "fourier_cond_init",
    ]
    # stats
    xe_loss = 0
    n_valid = torch.zeros(1000, dtype=torch.long)
    n_total = torch.zeros(1000, dtype=torch.long)
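    # n_valid / n_total are indexed by nb_ops (the number of operators per
    # equation, yielded by the iterator below); 1000 is just a generous
    # upper bound on that index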
    # evaluation details
    if params.eval_verbose:
        eval_path = os.path.join(
            params.dump_path, f"eval.{data_type}.{task}.{scores['epoch']}"
        )
        f_export = open(eval_path, "w")
        logger.info(f"Writing evaluation results in {eval_path} ...")
    # iterator
    iterator = self.env.create_test_iterator(
        data_type,
        task,
        data_path=self.trainer.data_path,
        batch_size=params.batch_size_eval,
        params=params,
        size=params.eval_size,
    )
    eval_size = len(iterator.dataset)
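    # each batch holds sequence-first (slen, bs) token tensors x1 / x2 with
    # their true lengths len1 / len2, plus nb_ops used to bucket the stats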
    for (x1, len1), (x2, len2), nb_ops in iterator:
        # print status
        if n_total.sum().item() % 500 < params.batch_size_eval:
            logger.info(f"{n_total.sum().item()}/{eval_size}")
        # target words to predict
        alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
        pred_mask = (
            alen[:, None] < len2[None] - 1
        )  # do not predict anything given the last target word
        y = x2[1:].masked_select(pred_mask[:-1])
        assert len(y) == (len2 - 1).sum().item()
        # optionally truncate input (no truncation is applied here)
        x1_, len1_ = x1, len1
        # cuda
        x1_, len1_, x2, len2, y = to_cuda(x1_, len1_, x2, len2, y)
        # forward / loss
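        # the encoder attends over the full input (causal=False); the decoder
        # is scored with teacher forcing on the gold target sequence x2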
        encoded = encoder("fwd", x=x1_, lengths=len1_, causal=False)
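        # the encoder output comes back sequence-first; transpose(0, 1) makes
        # it batch-first before it is passed to the decoder as src_enc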
        decoded = decoder(
            "fwd",
            x=x2,
            lengths=len2,
            causal=True,
            src_enc=encoded.transpose(0, 1),
            src_len=len1_,
        )
        word_scores, loss = decoder(
            "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=True
        )
        # correct outputs per sequence / valid top-1 predictions
        t = torch.zeros_like(pred_mask, device=y.device)
        t[pred_mask] += word_scores.max(1)[1] == y
        valid = (t.sum(0) == len2 - 1).cpu().long()
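        # a sequence is counted as valid only if every next-token prediction
        # matches the target (exact match under teacher forcing, no
        # free-running generation)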
        # export evaluation details
        if params.eval_verbose:
            for i in range(len(len1)):
                src = idx_to_infix(env, x1[1 : len1[i] - 1, i].tolist(), True)
                tgt = idx_to_infix(env, x2[1 : len2[i] - 1, i].tolist(), False)
                s = (
                    f"Equation {n_total.sum().item() + i} "
                    f"({'Valid' if valid[i] else 'Invalid'})\n"
                    f"src={src}\ntgt={tgt}\n"
                )
                if params.eval_verbose_print:
                    logger.info(s)
                f_export.write(s + "\n")
                f_export.flush()
        # stats
        xe_loss += loss.item() * len(y)
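        # bucket the per-sequence results by nb_ops so that per-class
        # accuracy can be reported at the end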
        n_valid.index_add_(-1, nb_ops, valid)
        n_total.index_add_(-1, nb_ops, torch.ones_like(nb_ops))

    # evaluation details
    if params.eval_verbose:
        f_export.close()
    # log
    _n_valid = n_valid.sum().item()
    _n_total = n_total.sum().item()
    logger.info(
        f"{_n_valid}/{_n_total} ({100. * _n_valid / _n_total}%) "
        "equations were evaluated correctly."
    )
    # compute cross-entropy loss and prediction accuracy
    assert _n_total == eval_size
    scores[f"{data_type}_{task}_xe_loss"] = xe_loss / _n_total
    scores[f"{data_type}_{task}_acc"] = 100.0 * _n_valid / _n_total
    # per-class prediction accuracy, bucketed by nb_ops
    for i in range(len(n_total)):
        if n_total[i].item() == 0:
            continue
        scores[f"{data_type}_{task}_acc_{i}"] = (
            100.0 * n_valid[i].item() / max(n_total[i].item(), 1)
        )