def _run_q_gen_process_local()

in preprocess/sm_inference_asum.py [0:0]


def _run_q_gen_process_local(job_idx, *, input_source_file, input_ans_file, out_text_file,
                             offset, end, checkpoint_dir, ckp_file, bin_dir, args):
    """Generate questions for the answer spans of each summary on one local GPU.

    The worker loads a fairseq ``BARTModel`` checkpoint onto ``cuda:<job_idx>`` in
    half precision. It then processes lines ``[offset, end)`` of two line-aligned
    JSON-lines inputs:

    * ``input_source_file``: each line is a JSON object with a ``'summaries'`` list.
    * ``input_ans_file``: each line is a JSON list with one list of answer strings
      per summary (its length must equal ``len(item['summaries'])``).

    Answer selection and encoding:

    * For every summary, exactly ``num_ans_per_hypo`` answers are picked by
      ``_filter_ans_list``.
    * Each (summary, answer) pair is BPE-encoded with ``_format_source_answers_bpe``.
    * The pairs are fed to ``bart.sample`` in batches of ``args.bsz``.

    Output format:

    * One JSON line is written to ``out_text_file`` per processed input line.
    * Each line is a list of ``{'context': summary, 'qa_list': [...]}`` objects.
    * Each ``qa_list`` entry is ``{'questions': ..., 'q_scores': ..., 'answer': ...}``.
    * Summaries whose answer list is empty produce no entry, so the output list can
      be shorter than ``'summaries'``.

    Args:
        job_idx: Worker index; also used as the CUDA device ordinal.
        input_source_file: Path of the summaries JSON-lines file.
        input_ans_file: Path of the answers JSON-lines file.
        out_text_file: Path of the output JSON-lines file (overwritten).
        offset: First input line (0-based) this worker handles.
        end: One past the last input line this worker handles.
        checkpoint_dir, ckp_file, bin_dir: Forwarded to ``BARTModel.from_pretrained``.
        args: Namespace providing ``encoder_json``, ``vocab_bpe``, ``bsz``, ``beam``,
            ``max_len``, ``min_len``, ``sampling``, ``sampling_topk`` and
            ``sampling_topp``.

    Raises:
        ValueError: if ``args.bsz`` is smaller than 1.
        AssertionError: if a source/answer line pair is misaligned, if the model
            returns a batch of unexpected size, or if the input ends before ``end``.
    """
    from itertools import groupby
    from fairseq.data.encoders.gpt2_bpe import get_encoder

    bsz = args.bsz
    # Validate before the (expensive) model load; a zero/negative batch size would
    # otherwise crash or mis-batch only once generation starts.
    if bsz < 1:
        raise ValueError("args.bsz must be >= 1, got {}".format(bsz))

    bart = BARTModel.from_pretrained(
        checkpoint_dir,
        checkpoint_file=ckp_file,
        data_name_or_path=bin_dir
    )
    # Select the device first so that .cuda() moves the model onto this worker's GPU.
    torch.cuda.set_device(torch.device("cuda:{}".format(job_idx)))
    bart.cuda()
    bart.eval()
    bart.half()

    # Presumably the token id that separates the source text from the answer in
    # the BPE input built by _format_source_answers_bpe -- TODO confirm.
    special_token_id = 50009
    num_ans_per_hypo = 10
    bpe = get_encoder(args.encoder_json, args.vocab_bpe)

    def _filter_ans_list(ans_list, black_set, num_ans_target):
        """Return exactly ``num_ans_target`` answers, preferring ones not in ``black_set``.

        Answers whose lower-cased form is in ``black_set`` are used only when there
        are not enough other answers. If ``ans_list`` itself is shorter than the
        target, it is padded by sampling with replacement from it. An empty
        ``ans_list`` yields ``[]``.
        """
        if not ans_list:
            return []
        if len(ans_list) < num_ans_target:
            return ans_list + random.choices(ans_list, k=num_ans_target - len(ans_list))
        good_ans = [ans for ans in ans_list if ans.lower() not in black_set]
        bad_ans = [ans for ans in ans_list if ans.lower() in black_set]
        return (good_ans + bad_ans)[:num_ans_target]

    def _sample_batch(model, slines):
        """Run ``model.sample`` on one batch of BPE strings.

        Returns one ``(hypotheses, score)`` pair per input line, in input order.
        """
        hypotheses_batch, score_batch, unnormalized_score_batch = model.sample(
            slines,
            beam=args.beam,
            lenpen=1.0,
            max_len_b=args.max_len,
            min_len=args.min_len,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            sampling_topp=args.sampling_topp,
            return_all=True,
            input_is_bpe=True
        )
        assert len(hypotheses_batch) == len(score_batch) == len(unnormalized_score_batch), \
            "lens not equal: {} and {} and {}".format(
                len(hypotheses_batch), len(score_batch), len(unnormalized_score_batch)
            )
        assert len(hypotheses_batch) == len(slines), "slines={}, generated_score length={}".format(
            slines, len(hypotheses_batch)
        )
        return list(zip(hypotheses_batch, score_batch))

    count = 0
    print("Local worker is processing {}-{}".format(offset, end))
    with torch.no_grad():
        with open(input_source_file, 'r', encoding='utf-8') as source_f, \
                open(input_ans_file, 'r', encoding='utf-8') as ans_f, \
                open(out_text_file, 'w', encoding='utf-8') as out_text_f:
            # Skip lines that belong to earlier workers; both files are line-aligned.
            for _ in range(offset):
                source_f.readline()
                ans_f.readline()
            source_line = source_f.readline()
            ans_line = ans_f.readline()
            while source_line and offset + count < end:
                source_item = json.loads(source_line.strip())
                ans_item = json.loads(ans_line.strip())

                assert len(source_item['summaries']) == len(ans_item)

                # Build (hypo_id, answer, bpe_input_line) triples for every summary.
                input_buffer = []
                for hypo_id, (source_text, ans_list_hypo) in enumerate(
                        zip(source_item['summaries'], ans_item)):
                    if not ans_list_hypo:
                        print("Answer span is empty!")
                        print(ans_item)
                    for answer in _filter_ans_list(ans_list_hypo, STOP_WORDS, num_ans_per_hypo):
                        _, source_answer_bpe = _format_source_answers_bpe(
                            bpe, source_text, answer, special_token_id)
                        input_buffer.append((hypo_id, answer, ' '.join(map(str, source_answer_bpe))))

                # Generate in chunks of bsz; results stay aligned with input_buffer.
                results = []
                for start in range(0, len(input_buffer), bsz):
                    batch = input_buffer[start:start + bsz]
                    generated = _sample_batch(bart, [bpe_line for _, _, bpe_line in batch])
                    for (hypo_id, answer, _), (hypos, scores) in zip(batch, generated):
                        results.append((hypo_id, answer, hypos, scores))

                # input_buffer is ordered by ascending hypo_id, so consecutive grouping
                # yields exactly one entry per summary that had answers.
                qa_item = [
                    {
                        'context': source_item['summaries'][hypo_id],
                        'qa_list': [
                            {'questions': hypos, 'q_scores': scores, 'answer': answer}
                            for _, answer, hypos, scores in group
                        ],
                    }
                    for hypo_id, group in groupby(results, key=lambda r: r[0])
                ]

                json.dump(qa_item, out_text_f)
                out_text_f.write('\n')

                source_line = source_f.readline()
                ans_line = ans_f.readline()
                count += 1
                if count % 100 == 0:
                    print("Generated {} lines from worker {}".format(count, job_idx))

        assert offset + count == end, "!worker ended at {}, should have been {}".format(
            offset + count,
            end
        )
        # Release the model and cached GPU memory before the worker returns.
        del bart
        torch.cuda.empty_cache()