def _run_qa_gen_process_local_batch_lines()

in preprocess/sm_inference_asum.py
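
This per-GPU worker loads a fine-tuned BART checkpoint, skips to its assigned half-open line range [offset, end) in the JSONL input_source_file, generates question-answer candidates for each record's summary in batches of args.bsz via _sample_wrapper, and writes one JSON line per input summary to out_text_file: a single-element list whose object carries 'context', 'qa', 'norm_scores', and 'unnorm_scores', plus 'pos_scores' and 'toks' when args.return_token_scores is set. The listing assumes module-level imports of json and torch, along with the BARTModel class (fairseq-style from_pretrained) and the module's _sample_wrapper helper.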


def _run_qa_gen_process_local_batch_lines(job_idx, *, input_source_file, out_text_file,
                             offset, end, checkpoint_dir, ckp_file, bin_dir, args):
    # Load the fine-tuned BART checkpoint, pin this worker to its own GPU,
    # and run generation in fp16 under eval mode.
    bart = BARTModel.from_pretrained(
        checkpoint_dir,
        checkpoint_file=ckp_file,
        data_name_or_path=bin_dir
    )
    torch.cuda.set_device(torch.device("cuda:{}".format(job_idx)))
    bart.cuda()
    bart.eval()
    bart.half()

    # Each worker covers the half-open line range [offset, end) and generates
    # in batches of args.bsz summaries.
    count = 1
    # bsz = 32
    bsz = args.bsz
    print("Local worker is processing {}-{}".format(offset, end))
    with torch.no_grad():
        with open(input_source_file, 'r') as source_f, \
                open(out_text_file, 'w') as out_text_f:
            # Skip ahead to this worker's offset, then seed the first batch with
            # the summary from the first assigned record.
            for _ in range(offset):
                source_f.readline()
            source_line = source_f.readline()
            source_item = json.loads(source_line.strip())
            assert len(source_item['summaries']) == 1
            slines = [source_item['summaries'][0].strip()]
            while source_line:
                # Stop once this worker has consumed its assigned range.
                if offset + count >= end:
                    break
                # Run generation whenever a full batch of summaries has accumulated.
                if count % bsz == 0:
                    hypotheses_batch, score_batch, unnormalized_score_batch, pos_score_batch, tokens_batch = \
                        _sample_wrapper(
                            bart,
                            sentences=slines,
                            beam=args.beam,
                            lenpen=1.0,
                            max_len_b=args.max_len,
                            min_len=args.min_len,
                            sampling=args.sampling,
                            sampling_topk=args.sampling_topk,
                            sampling_topp=args.sampling_topp,
                            return_all=args.return_all,
                            input_is_bpe=False,
                            return_token_scores=args.return_token_scores,
                            diverse_beam_groups=args.diverse_beam_groups,
                            diverse_beam_strength=args.diverse_beam_strength,
                        )
                    assert len(hypotheses_batch) == len(score_batch) == len(unnormalized_score_batch), \
                        "lens not equal: {} and {} and {}".format(
                        len(hypotheses_batch), len(score_batch), len(unnormalized_score_batch)
                    )
                    assert len(hypotheses_batch) == len(slines), "slines={}, generated_score length={}".format(
                        slines, len(hypotheses_batch)
                    )
                    # Write one JSON line per input summary; scalar outputs are
                    # wrapped in lists so single- and multi-hypothesis cases match.
                    if args.return_token_scores:
                        for t, s, unnormalized_s, pos_s, toks, sline in zip(hypotheses_batch, score_batch,
                                                                           unnormalized_score_batch,
                                                                           pos_score_batch, tokens_batch, slines):
                            qa_item = [{
                                'context': sline,
                                'qa': t if type(t) is list else [t, ],
                                'norm_scores': s if type(s) is list else [s, ],
                                'unnorm_scores': unnormalized_s if type(unnormalized_s) is list else [unnormalized_s, ],
                                'pos_scores': [tmp.tolist() for tmp in pos_s] if args.return_all and args.beam > 1 \
                                    else [pos_s.tolist(), ],
                                'toks': [tmp.tolist() for tmp in toks] if args.return_all and args.beam > 1 else \
                                    [toks.tolist(), ]
                            }, ]
                            json.dump(qa_item, out_text_f)
                            out_text_f.write('\n')
                    else:
                        for t, s, unnormalized_s, sline in zip(hypotheses_batch, score_batch, unnormalized_score_batch,
                                                               slines):
                            qa_item = [{
                                'context': sline,
                                'qa': t if type(t) is list else [t, ],
                                'norm_scores': s if type(s) is list else [s, ],
                                'unnorm_scores': unnormalized_s if type(unnormalized_s) is list else [unnormalized_s, ]
                            }, ]
                            json.dump(qa_item, out_text_f)
                            out_text_f.write('\n')
                    out_text_f.flush()
                    slines = []
                # Queue the next record's summary for the following batch.
                source_line = source_f.readline()
                source_item = json.loads(source_line.strip())
                slines.append(source_item['summaries'][0].strip())
                count += 1
                # if count % 100 == 0:
                #     print("Generated {} lines from worker {}".format(count, job_idx))

            # Generate for the final partial batch (fewer than bsz summaries).
            if slines != []:
                hypotheses_batch, score_batch, unnormalized_score_batch, pos_score_batch, tokens_batch = \
                    _sample_wrapper(
                        bart,
                        sentences=slines,
                        beam=args.beam,
                        lenpen=1.0,
                        max_len_b=args.max_len,
                        min_len=args.min_len,
                        sampling=args.sampling,
                        sampling_topk=args.sampling_topk,
                        sampling_topp=args.sampling_topp,
                        return_all=args.return_all,
                        input_is_bpe=False,
                        return_token_scores=args.return_token_scores,
                        diverse_beam_groups=args.diverse_beam_groups,
                        diverse_beam_strength=args.diverse_beam_strength,
                    )
                assert len(hypotheses_batch) == len(score_batch) == len(unnormalized_score_batch), \
                    "lens not equal: {} and {} and {}".format(
                        len(hypotheses_batch), len(score_batch), len(unnormalized_score_batch)
                    )
                assert len(hypotheses_batch) == len(slines), "slines={}, generated_score length={}".format(
                    slines, len(hypotheses_batch)
                )

                if args.return_token_scores:
                    for t, s, unnormalized_s, pos_s, toks, sline in zip(hypotheses_batch, score_batch,
                                                                        unnormalized_score_batch,
                                                                        pos_score_batch, tokens_batch, slines):
                        qa_item = [{
                            'context': sline,
                            'qa': t if type(t) is list else [t, ],
                            'norm_scores': s if type(s) is list else [s, ],
                            'unnorm_scores': unnormalized_s if type(unnormalized_s) is list else [unnormalized_s, ],
                            'pos_scores': [tmp.tolist() for tmp in pos_s] if args.return_all and args.beam > 1 else \
                                [pos_s.tolist(), ],
                            'toks': [tmp.tolist() for tmp in toks] if args.return_all and args.beam > 1 else \
                                [toks.tolist(), ]
                        }, ]
                        json.dump(qa_item, out_text_f)
                        out_text_f.write('\n')
                else:
                    for t, s, unnormalized_s, sline in zip(hypotheses_batch, score_batch, unnormalized_score_batch,
                                                           slines):
                        qa_item = [{
                            'context': sline,
                            'qa': t if type(t) is list else [t, ],
                            'norm_scores': s if type(s) is list else [s, ],
                            'unnorm_scores': unnormalized_s if type(unnormalized_s) is list else [unnormalized_s, ]
                        }, ]
                        json.dump(qa_item, out_text_f)
                        out_text_f.write('\n')
                out_text_f.flush()

        # Sanity check: the worker must have consumed exactly its assigned range.
        assert offset + count == end, "!worker ended at {}, should have been {}".format(
            offset + count,
            end
        )
        # Release the model and free this worker's GPU memory before exiting.
        del bart
        torch.cuda.empty_cache()
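
A worker like this is typically launched once per GPU over disjoint line ranges of the same input file. The sketch below is a hypothetical driver, not code from the repo: count_lines and launch_qa_gen_workers are illustrative names, out_prefix is an assumed naming scheme for per-worker output files, and args is the parsed argument namespace the script already builds (bsz, beam, max_len, sampling flags, and so on).

# --- Hypothetical driver sketch (not part of the repo) ---
import multiprocessing as mp


def count_lines(path):
    # Count the JSONL records so the file can be split into per-GPU ranges.
    with open(path, 'r') as f:
        return sum(1 for _ in f)


def launch_qa_gen_workers(input_source_file, out_prefix, checkpoint_dir,
                          ckp_file, bin_dir, args, n_gpus):
    total = count_lines(input_source_file)
    per_worker = (total + n_gpus - 1) // n_gpus
    # CUDA does not survive fork(), so use the spawn start method.
    ctx = mp.get_context("spawn")
    procs = []
    for job_idx in range(n_gpus):
        offset = job_idx * per_worker
        end = min(offset + per_worker, total)
        if offset >= end:
            break
        p = ctx.Process(
            target=_run_qa_gen_process_local_batch_lines,
            args=(job_idx,),
            kwargs=dict(
                input_source_file=input_source_file,
                out_text_file="{}.{}".format(out_prefix, job_idx),
                offset=offset,
                end=end,
                checkpoint_dir=checkpoint_dir,
                ckp_file=ckp_file,
                bin_dir=bin_dir,
                args=args,
            ),
        )
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

Under this partition scheme each worker writes its own out_text_file, so the per-worker outputs can be concatenated in job_idx order to recover the original line order.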