def enc_dec_step_beam_fast()

in src/evaluator.py [0:0]


    def enc_dec_step_beam_fast(self, data_type, task, scores, size=None):
        """
        Encoding / decoding step with beam generation and SymPy check.

        For every batch of the test iterator this method:
          1. runs a teacher-forced forward pass and records the cross-entropy
             loss and whether the greedy (argmax) prediction matches the
             reference token-for-token;
          2. for batches with at least one mismatch (or always, when
             ``params.eval_verbose == 2``), runs ``decoder.generate_beam`` and
             queues every beam hypothesis for checking;
        then checks all queued hypotheses in parallel with ``check_hypothesis``
        and writes the aggregated statistics into ``scores``.

        Args:
            data_type: dataset split name forwarded to
                ``env.create_test_iterator`` and used in score keys.
            task: one of "ode_convergence_speed", "ode_control",
                "fourier_cond_init".
            scores: dict updated in place with ``{data_type}_{task}_xe_loss``,
                ``{data_type}_{task}_beam_acc`` and per-``nb_ops``
                ``{data_type}_{task}_beam_acc_{i}`` entries. Must contain
                ``"epoch"`` when ``params.eval_verbose`` is set (used in the
                dump file name).
            size: unused; the evaluation size is taken from
                ``params.eval_size``. Kept for interface compatibility.
        """
        params = self.params
        env = self.env
        encoder = (
            self.modules["encoder"].module
            if params.multi_gpu
            else self.modules["encoder"]
        )
        decoder = (
            self.modules["decoder"].module
            if params.multi_gpu
            else self.modules["decoder"]
        )
        encoder.eval()
        decoder.eval()
        assert params.eval_verbose in [0, 1, 2]
        assert params.eval_verbose_print is False or params.eval_verbose > 0
        assert task in [
            "ode_convergence_speed",
            "ode_control",
            "fourier_cond_init",
        ]

        # stats (indexed by nb_ops; 1000 buckets)
        xe_loss = 0
        n_valid = torch.zeros(1000, dtype=torch.long)
        n_total = torch.zeros(1000, dtype=torch.long)

        # iterator
        iterator = env.create_test_iterator(
            data_type,
            task,
            data_path=self.trainer.data_path,
            batch_size=params.batch_size_eval,
            params=params,
            size=params.eval_size,
        )
        eval_size = len(iterator.dataset)

        # save beam results
        # beam_log maps a global equation index -> {src, tgt, nb_ops, hyps};
        # hyps is a list of (hyp, score, is_valid) tuples.
        beam_log = {}
        hyps_to_eval = []

        for (x1, len1), (x2, len2), nb_ops in iterator:

            # update logs
            # n_total has not been updated for this batch yet, so its sum is
            # the number of equations seen so far (global index offset).
            for i in range(len(len1)):
                beam_log[i + n_total.sum().item()] = {
                    "src": x1[1 : len1[i] - 1, i].tolist(),
                    "tgt": x2[1 : len2[i] - 1, i].tolist(),
                    "nb_ops": nb_ops[i].item(),
                    "hyps": [],
                }

            # target words to predict
            alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
            pred_mask = (
                alen[:, None] < len2[None] - 1
            )  # do not predict anything given the last target word
            y = x2[1:].masked_select(pred_mask[:-1])
            assert len(y) == (len2 - 1).sum().item()

            # optionally truncate input
            x1_, len1_ = x1, len1

            # cuda
            x1_, len1_, x2, len2, y = to_cuda(x1_, len1_, x2, len2, y)
            bs = len(len1)

            # forward
            encoded = encoder("fwd", x=x1_, lengths=len1_, causal=False)
            decoded = decoder(
                "fwd",
                x=x2,
                lengths=len2,
                causal=True,
                src_enc=encoded.transpose(0, 1),
                src_len=len1_,
            )
            word_scores, loss = decoder(
                "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=True
            )

            # correct outputs per sequence / valid top-1 predictions
            # a sequence is valid when every predicted position matches y
            t = torch.zeros_like(pred_mask, device=y.device)
            t[pred_mask] += word_scores.max(1)[1] == y
            valid = (t.sum(0) == len2 - 1).cpu().long()

            # update stats
            # presumably `loss` is a per-token mean, so this accumulates the
            # summed token loss -- TODO confirm against decoder("predict")
            xe_loss += loss.item() * len(y)
            n_valid.index_add_(-1, nb_ops, valid)
            n_total.index_add_(-1, nb_ops, torch.ones_like(nb_ops))

            # update equations that were solved greedily
            # n_total now includes this batch, hence the "- bs" offset
            for i in range(len(len1)):
                if valid[i]:
                    beam_log[i + n_total.sum().item() - bs]["hyps"].append(
                        (None, None, True)
                    )

            # continue if everything is correct. if eval_verbose, perform
            # a full beam search, even on correct greedy generations
            if valid.sum() == len(valid) and params.eval_verbose < 2:
                continue

            # invalid top-1 predictions - check if there is a solution in the beam
            invalid_idx = (1 - valid).nonzero().view(-1)
            logger.info(
                f"({n_total.sum().item()}/{eval_size}) Found "
                f"{bs - len(invalid_idx)}/{bs} valid top-1 predictions. "
                "Generating solutions ..."
            )

            # generate with beam search
            _, _, generations = decoder.generate_beam(
                encoded.transpose(0, 1),
                len1_,
                beam_size=params.beam_size,
                length_penalty=params.beam_length_penalty,
                early_stopping=params.beam_early_stopping,
                max_len=params.max_len,
            )

            # prepare inputs / hypotheses to check
            # if eval_verbose < 2, no beam search on equations solved greedily
            for i in range(len(generations)):
                if valid[i] and params.eval_verbose < 2:
                    continue
                # hypotheses are queued best-score first
                for j, (score, hyp) in enumerate(
                    sorted(generations[i].hyp, key=lambda x: x[0], reverse=True)
                ):
                    hyps_to_eval.append(
                        {
                            "i": i + n_total.sum().item() - bs,
                            "j": j,
                            "score": score,
                            "src": x1[1 : len1[i] - 1, i].tolist(),
                            "tgt": x2[1 : len2[i] - 1, i].tolist(),
                            "hyp": hyp[1:].tolist(),
                            "task": task,
                        }
                    )

        # if the Jacobian is also predicted, only look at the eigenvalue
        # (keep only the tokens after the first matrix separator, if any)
        if task == "ode_convergence_speed":
            sep_id = env.word2id[env.mtrx_separator]
            for x in hyps_to_eval:
                x["tgt"] = (
                    x["tgt"][x["tgt"].index(sep_id) + 1 :]
                    if sep_id in x["tgt"]
                    else x["tgt"]
                )
                x["hyp"] = (
                    x["hyp"][x["hyp"].index(sep_id) + 1 :]
                    if sep_id in x["hyp"]
                    else x["hyp"]
                )

        # solutions that perfectly match the reference with greedy decoding
        assert all(
            len(v["hyps"]) == 0
            or len(v["hyps"]) == 1
            and v["hyps"][0] == (None, None, True)
            for v in beam_log.values()
        )
        init_valid = sum(
            int(len(v["hyps"]) == 1 and v["hyps"][0][2] is True)
            for v in beam_log.values()
        )
        logger.info(
            f"Found {init_valid} solutions with greedy decoding "
            "(perfect reference match)."
        )

        # check hypotheses with multiprocessing
        eval_hyps = []
        start = time.time()
        logger.info(
            f"Checking {len(hyps_to_eval)} hypotheses for "
            f"{len(set(h['i'] for h in hyps_to_eval))} equations ..."
        )
        with ProcessPoolExecutor(max_workers=20) as executor:
            for output in executor.map(check_hypothesis, hyps_to_eval, chunksize=1):
                eval_hyps.append(output)
        logger.info(f"Evaluation done in {time.time() - start:.2f} seconds.")

        # update beam logs
        for hyp in eval_hyps:
            beam_log[hyp["i"]]["hyps"].append(
                (hyp["hyp"], hyp["score"], hyp["is_valid"])
            )

        # print beam results
        # beam_valid: solved, and the first entry is a beam hypothesis (i.e. the
        # equation was NOT solved greedily, whose marker has score None)
        beam_valid = sum(
            int(any(h[2] for h in v["hyps"]) and v["hyps"][0][1] is not None)
            for v in beam_log.values()
        )
        all_valid = sum(int(any(h[2] for h in v["hyps"])) for v in beam_log.values())
        assert init_valid + beam_valid == all_valid
        assert len(beam_log) == n_total.sum().item()
        logger.info(
            f"Found {all_valid} valid solutions ({init_valid} with greedy decoding "
            f"(perfect reference match), {beam_valid} with beam search)."
        )

        # update valid equation statistics
        # recomputed from beam_log so that beam-search solutions are counted too
        n_valid = torch.zeros(1000, dtype=torch.long)
        for i, v in beam_log.items():
            if any(h[2] for h in v["hyps"]):
                n_valid[v["nb_ops"]] += 1
        assert n_valid.sum().item() == all_valid

        # export evaluation details
        if params.eval_verbose:

            eval_path = os.path.join(
                params.dump_path, f"eval.beam.{data_type}.{task}.{scores['epoch']}"
            )

            with open(eval_path, "w") as f:

                # for each equation
                for i, res in sorted(beam_log.items()):
                    n_eq_valid = sum([int(v) for _, _, v in res["hyps"]])
                    src = idx_to_infix(env, res["src"], input=True).replace("|", " | ")
                    tgt = " ".join(env.id2word[wid] for wid in res["tgt"])
                    s = (
                        f"Equation {i} ({n_eq_valid}/{len(res['hyps'])})\n"
                        f"src={src}\ntgt={tgt}\n"
                    )
                    for hyp, score, valid in res["hyps"]:
                        if score is None:
                            assert hyp is None
                            s += f"{int(valid)} GREEDY\n"
                        else:
                            try:
                                hyp = " ".join(hyp)
                            except Exception:
                                hyp = f"INVALID OUTPUT {hyp}"
                            s += f"{int(valid)} {score :.3e} {hyp}\n"
                    if params.eval_verbose_print:
                        logger.info(s)
                    f.write(s + "\n")
                    f.flush()

            logger.info(f"Evaluation results written in {eval_path}")

        # log
        _n_valid = n_valid.sum().item()
        _n_total = n_total.sum().item()
        # guard against an empty evaluation set, which previously raised
        # ZeroDivisionError here; identical results whenever _n_total > 0
        _denom = max(_n_total, 1)
        logger.info(
            f"{_n_valid}/{_n_total} ({100. * _n_valid / _denom}%) equations "
            "were evaluated correctly."
        )

        # compute perplexity and prediction accuracy
        # note: xe_loss is normalized by the number of equations, not tokens
        assert _n_total == eval_size
        scores[f'{data_type}_{task}_xe_loss'] = xe_loss / _denom
        scores[f"{data_type}_{task}_beam_acc"] = 100.0 * _n_valid / _denom

        # per class perplexity and prediction accuracy
        # (one entry per nb_ops value that occurred in the test set)
        for i in range(len(n_total)):
            if n_total[i].item() == 0:
                continue
            logger.info(
                f"{i}: {n_valid[i].sum().item()} / {n_total[i].item()} "
                f"({100. * n_valid[i].sum().item() / max(n_total[i].item(), 1)}%)"
            )
            scores[f"{data_type}_{task}_beam_acc_{i}"] = (
                100.0 * n_valid[i].sum().item() / max(n_total[i].item(), 1)
            )