def eval_singletoken()

in custom/gpt2/run_gpt2.py [0:0]


def eval_singletoken(model, args, dataset_paths, train_iter=None):
    """Run single-token (next-token) evaluation and save the aggregated metrics.

    The split named by ``args.eval_split`` is read sequentially with a batch
    size of 1. For every batch the model is fed the first
    ``args.batch_size_singletoken`` positions of the sample, and its logits
    are scored against the same sample shifted left by one position.

    Args:
        model: Callable model whose output's first element holds the logits.
            It is switched to eval mode and left in that mode. It must also
            expose ``config.to_dict()``.
        args: Namespace providing ``batch_size_singletoken`` and
            ``eval_split``; it is also forwarded to
            ``save_singletoken_metrics``.
        dataset_paths: Forwarded unchanged to ``get_datasets``.
        train_iter: Optional value forwarded to ``save_singletoken_metrics``.

    Returns:
        The mapping produced by
        ``CrossEntropyCriterionWCustomMetrics.aggregate_logging_outputs``,
        extended with the keys ``'ppl'``, ``'uniq'`` and ``'human_uniq'``.

    Raises:
        AssertionError: If the model output has a leading dimension other
            than 1.
    """
    max_len = args.batch_size_singletoken
    eval_data = get_datasets(dataset_paths, max_len=max_len)[args.eval_split]
    loader = DataLoader(
        eval_data,
        sampler=SequentialSampler(eval_data),
        batch_size=1,
    )

    model.eval()

    per_batch_logs = []
    # Only the number of distinct token ids is reported, so sets are enough.
    predicted_vocab = set()
    target_vocab = set()

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", total=len(loader)):
            # NOTE(review): requires a CUDA device; presumably the model is
            # already on the GPU -- confirm against the caller.
            sample = batch[0].cuda()
            context = sample[:, :max_len]
            # Targets are the sample shifted by one position.
            # NOTE(review): this only lines up with the logits if each sample
            # holds one more token than ``context`` -- confirm in get_datasets.
            target = sample[:, 1:]

            logits = model(context)[0]
            lprobs = F.log_softmax(logits, dim=-1)
            assert lprobs.size(0) == 1, 'We work on flat sequences'

            seq_logits = logits[0]
            seq_target = target[0]

            # Summed negative log-likelihood of the target tokens.
            loss = F.nll_loss(lprobs[0], seq_target, reduction='sum')
            # nll_loss applied to raw logits yields -logit[target] for each
            # position; negating it recovers the logit of each target token.
            true_token_logits = -F.nll_loss(seq_logits, seq_target, reduction='none')

            predicted_vocab.update(lprobs.argmax(dim=-1).view(-1).tolist())

            ntokens = context.numel()
            batch_log = TrainingMetrics.ranking_metrics(
                seq_logits, true_token_logits, None, ntokens, seq_target
            )
            batch_log['loss'] = loss.item()
            for count_key in ('normalizer', 'sample_size', 'ntokens'):
                batch_log[count_key] = ntokens
            per_batch_logs.append(batch_log)

            # for human uniq
            target_vocab.update(target.view(-1).tolist())

    summary = CrossEntropyCriterionWCustomMetrics.aggregate_logging_outputs(per_batch_logs)
    # NOTE(review): 2 ** loss is a perplexity only if the aggregated loss is
    # expressed in base 2 -- confirm against aggregate_logging_outputs.
    summary['ppl'] = 2 ** summary['loss']
    summary['uniq'] = len(predicted_vocab)
    summary['human_uniq'] = len(target_vocab)

    save_singletoken_metrics(summary, model.config.to_dict(), args, train_iter=train_iter)
    return summary