def _run_qa_gen_process_local()

in preprocess/sm_inference_asum.py [0:0]


def _qa_gen_sample_in_batches(bart, texts, bsz, args):
    """Run ``_sample_wrapper`` over ``texts`` in consecutive chunks of ``bsz``.

    Args:
        bart: The loaded BART hub model passed straight to ``_sample_wrapper``.
        texts: Input strings; they are batched in order.
        bsz: Maximum number of texts per ``_sample_wrapper`` call (assumed a
            positive int).
        args: Namespace supplying the generation options forwarded to
            ``_sample_wrapper`` (beam, max_len, min_len, sampling*, return_all,
            return_token_scores, diverse_beam_*).

    Returns:
        A list with one tuple per input text, in input order:
        ``(hypotheses, score, unnormalized_score)``, extended with
        ``(pos_scores, tokens)`` when ``args.return_token_scores`` is set.
    """
    texts = list(texts)
    results = []
    for start in range(0, len(texts), bsz):
        slines = texts[start:start + bsz]
        hypotheses_batch, score_batch, unnormalized_score_batch, pos_score_batch, tokens_batch = \
            _sample_wrapper(
                bart,
                sentences=slines,
                beam=args.beam,
                lenpen=1.0,
                max_len_b=args.max_len,
                min_len=args.min_len,
                sampling=args.sampling,
                sampling_topk=args.sampling_topk,
                sampling_topp=args.sampling_topp,
                return_all=args.return_all,
                input_is_bpe=False,
                return_token_scores=args.return_token_scores,
                diverse_beam_groups=args.diverse_beam_groups,
                diverse_beam_strength=args.diverse_beam_strength,
            )
        assert len(hypotheses_batch) == len(score_batch) == len(unnormalized_score_batch), \
            "lens not equal: {} and {} and {}".format(
                len(hypotheses_batch), len(score_batch), len(unnormalized_score_batch)
            )
        # One result per input is required so results can be zipped back onto
        # the source texts by position.
        assert len(hypotheses_batch) == len(slines), "slines={}, generated_score length={}".format(
            slines, len(hypotheses_batch)
        )
        if args.return_token_scores:
            results.extend(zip(hypotheses_batch, score_batch, unnormalized_score_batch,
                               pos_score_batch, tokens_batch))
        else:
            results.extend(zip(hypotheses_batch, score_batch, unnormalized_score_batch))
    return results


def _run_qa_gen_process_local(job_idx, *, input_source_file, out_text_file,
                             offset, end, checkpoint_dir, ckp_file, bin_dir, args):
    """Generate QA outputs for lines ``[offset, end)`` of a JSON-lines file on one GPU.

    Each input line is parsed as JSON and must contain a ``'summaries'`` list of
    strings. Every summary is fed to the BART model (in batches of ``args.bsz``)
    and, for each input line, one JSON list is written to ``out_text_file``
    (one line per input line), with one entry per summary::

        {'context': summary, 'qa': ..., 'norm_scores': ..., 'unnorm_scores': ...,
         # only when args.return_token_scores:
         'pos_scores': [...], 'toks': [...]}

    Args:
        job_idx: Index of the CUDA device this worker runs on (``cuda:<job_idx>``).
        input_source_file: Path of the JSON-lines input file.
        out_text_file: Path of the output file (overwritten).
        offset: First input line (0-based) this worker processes.
        end: One past the last input line this worker processes.
        checkpoint_dir, ckp_file, bin_dir: Forwarded to
            ``BARTModel.from_pretrained`` as the checkpoint directory,
            checkpoint file and ``data_name_or_path``.
        args: Namespace with ``bsz`` plus the generation options forwarded to
            ``_sample_wrapper``.

    Raises:
        ValueError: If ``args.bsz`` is smaller than 1.
        AssertionError: If the input file ends before line ``end`` or the
            generator returns mismatched result lengths.
    """
    import itertools

    bsz = args.bsz
    # Validate before the (expensive) model load; bsz < 1 previously surfaced
    # as an opaque ZeroDivisionError in the middle of a run.
    if bsz < 1:
        raise ValueError("args.bsz must be a positive integer, got {}".format(bsz))

    bart = BARTModel.from_pretrained(
        checkpoint_dir,
        checkpoint_file=ckp_file,
        data_name_or_path=bin_dir
    )
    torch.cuda.set_device(torch.device("cuda:{}".format(job_idx)))
    bart.cuda()
    bart.eval()
    bart.half()

    count = 0
    print("Local worker is processing {}-{}".format(offset, end))
    try:
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with torch.no_grad(), \
                open(input_source_file, 'r', encoding='utf-8') as source_f, \
                open(out_text_file, 'w', encoding='utf-8') as out_text_f:
            # islice skips the first `offset` lines and stops at `end` (or EOF).
            for source_line in itertools.islice(source_f, offset, end):
                source_item = json.loads(source_line.strip())
                summaries = source_item['summaries']

                hypos = _qa_gen_sample_in_batches(bart, summaries, bsz, args)

                qa_item = []
                for context, hypo in zip(summaries, hypos):
                    entry = {'context': context,
                             'qa': hypo[0], 'norm_scores': hypo[1], 'unnorm_scores': hypo[2]}
                    if args.return_token_scores:
                        # presumably tensors (they expose .tolist()) — converted
                        # so the record is JSON-serializable.
                        entry['pos_scores'] = [tmp.tolist() for tmp in hypo[3]]
                        entry['toks'] = [tmp.tolist() for tmp in hypo[4]]
                    qa_item.append(entry)

                json.dump(qa_item, out_text_f)
                out_text_f.write('\n')
                count += 1

        # A short count means the input file has fewer lines than `end`.
        assert offset + count == end, "!worker ended at {}, should have been {}".format(
            offset + count,
            end
        )
    finally:
        # Release GPU memory even when generation or the check above fails.
        del bart
        torch.cuda.empty_cache()