def get_bleu4()

in retrieval_eval_bleu.py [0:0]


def get_bleu4(split, history_len=1):
    """
    Print BLEU scores and output contexts and retrieved responses.

    For every example of the selected task/split, the context encoded with the
    retrieval network's dictionary is passed to ``predict``; the top retrieved
    response is then scored against the gold response using a BLEU scorer and
    a separate BLEU dictionary. Each example is written to
    ``retrieved_split_<args.name>_<split>.txt`` as one tab-separated line:
    ``cid, sid, context, response, source``.

    After all examples are processed, BLEU-1..BLEU-4 are printed, followed by
    how often each candidate source supplied the top response, compared with
    the ``actual_ct`` counts defined outside this function (presumably the
    size of each candidate pool -- confirm against the caller).

    Args:
        split: split name passed to the DailyDialog / EmpatheticDialogues
            datasets. It is ignored for the Reddit task, which always loads
            chunk 999.
        history_len: number of context turns to keep; values below 1 are
            clamped to 1.
    """
    history_len = max(history_len, 1)
    # Slot 0: "Reddit", slot 1: "EmpChat", slot 2: any other source (reported
    # as DailyDialog in the summary below).
    source_ct = [0, 0, 0]
    net_parlai_dict = ParlAIDictionary.create_from_reddit_style(net_dictionary)
    bleu_parlai_dict = ParlAIDictionary.create_from_reddit_style(bleu_dictionary)
    scorer = bleu.Scorer(BLEU_PAD_IDX, BLEU_EOS_IDX, BLEU_UNK_IDX)

    def _get_dataset(reddit_dict, parlai_dict):
        # Build the dataset for args.task; Reddit uses the Reddit-style dict,
        # the other tasks use the ParlAI-style wrapper around it.
        if args.task == "dailydialog":
            return DDDataset(
                split,
                parlai_dict,
                data_folder=args.dailydialog_folder,
                history_len=history_len,
            )
        elif args.task == "empchat":
            return EmpDataset(
                split,
                parlai_dict,
                data_folder=args.empchat_folder,
                history_len=history_len,
                reactonly=args.reactonly,
                fasttext=args.fasttext,
                fasttext_type=args.fasttext_type,
                fasttext_path=args.fasttext_path,
            )
        elif args.task == "reddit":
            return RedditDataset(
                data_folder=args.reddit_folder,
                chunk_id=999,
                dict_=reddit_dict,
                max_hist_len=history_len,
                rm_blank_sentences=True,
            )
        else:
            raise ValueError("Task unrecognized!")

    def _ratio(num, den):
        # Guard against ZeroDivisionError for an empty split / empty counts.
        return float(num) / den if den else float("nan")

    net_dataset = _get_dataset(net_dictionary, net_parlai_dict)
    bleu_dataset = _get_dataset(bleu_dictionary, bleu_parlai_dict)
    out_path = "retrieved_split_" + args.name + "_" + split + ".txt"
    # Context manager guarantees the output is flushed and closed; explicit
    # UTF-8 avoids locale-dependent UnicodeEncodeError on retrieved text.
    with open(out_path, "w", encoding="utf-8") as outf:
        for data_idx in range(len(bleu_dataset)):
            net_context, _ = net_dataset[data_idx][:2]
            bleu_context, target_tokens = bleu_dataset[data_idx][:2]
            if args.fasttext is not None:
                # Drop the first args.fasttext tokens (presumably a fastText
                # prefix added by the dataset -- confirm) before scoring.
                target_tokens = target_tokens[args.fasttext :]
            context = bleu_parlai_dict.vec2txt(bleu_context.numpy().tolist())
            responses, sources = predict(net_context)
            response = responses[0][0]
            source = sources[0]
            if source == "Reddit":
                source_ct[0] += 1
            elif source == "EmpChat":
                source_ct[1] += 1
            else:
                source_ct[2] += 1
            if args.task == "empchat":
                cid, sid = bleu_dataset.getid(data_idx)
            else:
                # Hack: only the EmpatheticDialogues dataset has .getid(), so
                # other tasks get -1 placeholders.
                cid = sid = -1
            if args.fasttext is not None:
                # Strip the same number of leading tokens from the hypothesis.
                response = " ".join(response.split()[args.fasttext :])
            row = [str(cid), str(sid), context, response, source]
            outf.write("\t".join(row) + "\n")
            # Use this tokenization even if a BERT tokenizer exists, to match
            # the BLEU calculation when not using BERT.
            hypo_tokens = torch.IntTensor(bleu_parlai_dict.txt2vec(response))
            scorer.add(target_tokens.type(torch.IntTensor), hypo_tokens)

    for order in range(1, 5):
        print(scorer.result_string(order=order))
    print(actual_ct)
    total_selected = sum(source_ct)
    total_actual = sum(actual_ct)
    # NOTE(review): the "selected ...%" value is a fraction in [0, 1], not a
    # percentage; the label is kept unchanged so existing log parsing works.
    for label, slot in (("EmpatheticDialogues", 1), ("DailyDialog", 2), ("Reddit", 0)):
        print(
            f"{label} {int(source_ct[slot]):d}: selected "
            f"{_ratio(source_ct[slot], total_selected)}%, but total: "
            f"{_ratio(actual_ct[slot], total_actual)}"
        )