def aug_eval(partition, epoch)

in train.py


def aug_eval(partition, epoch):
    """Evaluate per-example losses on `partition` for every combination of
    enabled augmentation types, writing one .npz of results per combination."""
    tx, ty = get_data(partition)
    if H.aug_eval_n_examples is not None:
        tx = tx[:H.aug_eval_n_examples]
        if ty is not None:
            ty = ty[:H.aug_eval_n_examples]
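    # Restrict to the self-generation augmentations that are enabled; if none
    # are, fall back to a single identity augmentation so the sweep below
    # still evaluates the unaugmented model once.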
    gen_in_use = [gen for gen in H.self_gen_types if gen.is_used]
    if not gen_in_use:
        gen_in_use = [AugmentationType("sos", "identity", 1, True, identity)]
    aug_choices = [gen.num_tokens for gen in gen_in_use]
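    # Sweep every combination of augmentation ids, one losses file per
    # combination; existing files are skipped, so an interrupted sweep
    # resumes where it left off.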
    for aug_types in go_over(aug_choices):
        fname = os.path.join(
            H.model_dir,
            f"{H.desc}_" + "_".join(map(str, aug_types)) + "_losses.npz")
        if os.path.exists(fname):
            if mpi_rank() == 0:
                print(f" Evaluated {fname}")
            continue
        if mpi_rank() == 0:
            print(f"Evaluating {fname}")
        losses = []
        imgs = []
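        # Each MPI rank evaluates a disjoint 1/mpi_size() shard of tx.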
        for data in iter_data_mpi(tx, n_batch=H.n_batch, log=logprint,
                                  split_by_rank=dataset.full_dataset_valid):
            x_emb = np.concatenate([H.x_emb.copy() for _ in range(H.n_batch)], axis=0)
            # Feed the batched embedding copy rather than the raw H.x_emb
            # template, so the placeholder shape matches the batch.
            feeds = {H.X_ph: data[0], H.X_emb_ph: x_emb}
            d_in = data[0]
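            # Apply each enabled augmentation with its id fixed by aug_types,
            # threading the transformed inputs and embeddings through every
            # generator in turn.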
            if H.num_self_gen_in_use > 0:
                y_gen_list = []
                for aug_type, gen in zip(aug_types, gen_in_use):
                    if gen.sos_name == 'sos_data':
                        raise NotImplementedError("sos_data is not supported in aug_eval")
                    yy = np.full((H.n_batch, 1), aug_type, dtype=np.int32)
                    d_in, x_emb, y_gen = gen.fn(d_in, x_emb, yy=yy)
                    assert (y_gen == yy).all()
                    y_gen_list.append(y_gen)
                feeds[H.X_ph] = d_in
                if H.permute_embeddings:
                    feeds[H.X_emb_ph] = x_emb
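                # Unless augmentation is unconditional, the augmentation ids
                # are also fed to the model; sanity-check they match aug_types.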
                if not H.use_unconditional_augmentation:
                    feeds[H.Y_gen_ph] = np.concatenate(y_gen_list, axis=1)
                    assert (feeds[H.Y_gen_ph] == np.stack([aug_types] * H.n_batch)).all()
            imgs.append(d_in)
            cur_loss = sess.run(H.eval_gen_losses, feeds)
            assert cur_loss.shape[0] == H.n_batch
            losses.append(cur_loss)

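        # Gather per-example losses from every rank so all processes see the
        # full-dataset result; only rank 0 writes it to disk.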
        losses = np.concatenate(losses, axis=0).astype(np.float32)
        assert losses.shape[0] == tx.shape[0] // mpi_size()
        mpi_barrier()
        losses = mpi_allgather(losses)
        assert losses.shape[0] == tx.shape[0]
        loss = losses.mean()

        content = dict(epoch=epoch, aug_types=aug_types, loss=loss, bpd=loss / np.log(2.0))
        logprint(**content)
        content["losses"] = losses
        if mpi_rank() == 0:
            np.savez(fname, **content)

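        # Likewise gather the (augmented) inputs and save them for non-test
        # partitions, so the exact evaluated data can be inspected later.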
        imgs = np.concatenate(imgs, axis=0)
        assert imgs.shape[0] == tx.shape[0] // mpi_size()
        mpi_barrier()
        imgs = mpi_allgather(imgs)
        assert imgs.shape == tx.shape
        if mpi_rank() == 0 and partition != "test":
            fname = os.path.join(H.model_dir, f"{H.desc}_" + "_".join(map(str, aug_types)) + "_imgs.npz")
            np.savez(fname, imgs=imgs.reshape(dataset.orig_shape))
        mpi_barrier()
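
The helpers referenced above are defined elsewhere in the repo. As a rough
guide, here is a minimal sketch consistent with how aug_eval uses them; the
field order of AugmentationType and the Cartesian-product behavior of go_over
are assumptions inferred from the call sites, not the actual implementations:

from collections import namedtuple
from itertools import product

# Assumed field layout, matching AugmentationType("sos", "identity", 1, True, identity).
AugmentationType = namedtuple(
    "AugmentationType", ["sos_name", "name", "num_tokens", "is_used", "fn"])

def identity(d_in, x_emb, yy=None):
    # Identity "augmentation": leave data and embeddings untouched and echo
    # the requested augmentation ids back as y_gen.
    return d_in, x_emb, yy

def go_over(choices):
    # Assumed behavior: enumerate every combination of augmentation ids,
    # i.e. the Cartesian product of range(n) over each generator's n tokens.
    yield from product(*(range(n) for n in choices))

A typical call site would be aug_eval("valid", epoch) at the end of an
evaluation epoch; rank 0 then holds one losses .npz (and, for non-test
partitions, one imgs .npz) per augmentation combination.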