def _run_qa_eval_process_local()

in preprocess/sm_inference_asum.py [0:0]
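
Worker routine for QA-based evaluation: it loads a fine-tuned BART checkpoint onto GPU job_idx and, for its shard of lines [offset, end) in the parallel source/target/QA files, scores every question-answer pair against the (optionally target-prepended) source, writing one augmented JSON record per line to out_text_file.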


# Module-level imports this excerpt relies on:
import json

import torch
from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.models.bart import BARTModel
from fairseq.sequence_scorer import SequenceScorer


def _run_qa_eval_process_local(job_idx, *, input_source_file, input_target_file, input_qas_file, out_text_file,
                               offset, end, checkpoint_dir, ckp_file, bin_dir, args):
    # Load the fine-tuned BART checkpoint used for scoring.
    bart = BARTModel.from_pretrained(
        checkpoint_dir,
        checkpoint_file=ckp_file,
        data_name_or_path=bin_dir
    )
    # Pin this worker to its own GPU, then run the model there in fp16 eval mode.
    torch.cuda.set_device(torch.device("cuda:{}".format(job_idx)))
    bart.cuda()
    bart.eval()
    bart.half()

    def batch_for_scorer(source_tokens_list, num_source_token_list, target_tokens_list, num_target_token_list, bsz):
        """Yield successive chunks of at most bsz parallel source/target entries."""
        length = len(source_tokens_list)
        s = 0
        while s < length:
            e = s + bsz
            yield source_tokens_list[s:e], num_source_token_list[s:e], \
                  target_tokens_list[s:e], num_target_token_list[s:e]
            s = e
    # Id of the special token separating question from answer in a QA pair
    # (it replaces the question's trailing EOS below).
    special_token = 50259

    count = 0
    bsz = args.bsz
    max_source_tokens = 1024
    # Build the scorer once; it is stateless and reused across all batches.
    generator = SequenceScorer(bart.task.target_dictionary, compute_alignment=False)
    print("Local worker is processing {}-{}".format(offset, end))
    with torch.no_grad():
        with open(input_source_file, 'r') as source_f, \
                open(input_qas_file, 'r') as qas_f, \
                open(input_target_file, 'r') as target_f, \
                open(out_text_file, 'w') as out_text_f:
            # Skip to this worker's shard, then prime the first parallel line triple.
            for _ in range(offset):
                source_f.readline()
                target_f.readline()
                qas_f.readline()
            source_line = source_f.readline()
            target_line = target_f.readline()
            qas_line = qas_f.readline()
            while source_line:
                if offset + count >= end:
                    break

                if args.prepend_target:
                    src_tokens = bart.encode(target_line.strip() + ' ' + source_line.strip(), no_bos=True,
                                             input_is_bpe=False)
                else:
                    src_tokens = bart.encode(source_line.strip(), no_bos=True, input_is_bpe=False)
                # Truncate over-long inputs, keeping the EOS token at the cutoff position.
                if len(src_tokens) > max_source_tokens:
                    src_tokens[max_source_tokens - 1] = src_tokens[-1]
                    src_tokens = src_tokens[:max_source_tokens]

                qas_item = json.loads(qas_line.strip())
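                # Illustrative structure of one qas line (keys per the loop below;
                # the example values are assumptions, not from the source):
                #   [{"qas": [{"q": "Who scored?", "a": "Smith"},
                #             {"toks": [50259, ...]}]},
                #    ...]  # one entry per hypothesis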

                qa_tensors = []
                for hypo_qas in qas_item:
                    for qa in hypo_qas['qas']:
                        if 'toks' in qa:
                            # Pre-tokenized QA pair supplied directly.
                            qa_tensors.append(torch.LongTensor(qa['toks']))
                        else:
                            # Encode question and answer, joining them with the
                            # special separator in place of the question's EOS.
                            q_tensor = bart.encode(qa['q'], no_bos=True, input_is_bpe=False)
                            q_tensor[-1] = special_token
                            a_tensor = bart.encode(qa['a'], no_bos=True, input_is_bpe=False)
                            qa_tensors.append(torch.cat((q_tensor, a_tensor)))

                num_src_tokens = src_tokens.numel()
                # Pair the same source with every QA target.
                src_tokens_list = [src_tokens for _ in range(len(qa_tensors))]
                num_src_token_list = [num_src_tokens for _ in range(len(qa_tensors))]
                hypos = []
                for s_list, num_s_list, t_list, num_t_list in batch_for_scorer(src_tokens_list, num_src_token_list,
                                                                               qa_tensors,
                                                                               [x.numel() for x in qa_tensors], bsz):
                    # batch_for_scorer always yields lists, but guard anyway.
                    if not isinstance(s_list, list):
                        s_list = [s_list]
                    if not isinstance(num_s_list, list):
                        num_s_list = [num_s_list]
                    if not isinstance(t_list, list):
                        t_list = [t_list]
                    if not isinstance(num_t_list, list):
                        num_t_list = [num_t_list]

                    # Wrap the batch in a LanguagePairDataset and collate it into a
                    # single padded sample, moved to the GPU.
                    dataset = LanguagePairDataset(s_list, num_s_list,
                                                  bart.task.source_dictionary,
                                                  t_list, num_t_list,
                                                  bart.task.target_dictionary,
                                                  shuffle=False)
                    sample = dataset.collater(dataset)
                    sample = utils.apply_to_sample(lambda tensor: tensor.cuda(), sample)
                    translations = bart.task.inference_step(
                        generator,
                        [bart.model],
                        sample,
                    )
                    # Restore the original within-batch order (collation sorts by source length).
                    translations = [v for _, v in sorted(zip(sample['id'].tolist(), translations),
                                                         key=lambda pair: pair[0])]
                    hypos += translations
                qa_id = 0
                for hypo_qas in qas_item:
                    for qa in hypo_qas['qas']:
                        hypo = hypos[qa_id]
                        # Normalized / unnormalized scores over the full QA sequence.
                        qa['eval_ns'] = hypo[0]['score'].item()
                        qa['eval_uns'] = hypo[0]['positional_scores'].sum().item()
                        # Answer-only scores: positions after the separator, excluding EOS.
                        special_token_loc = (hypo[0]['tokens'] == special_token).nonzero()
                        ans_scores = hypo[0]['positional_scores'][special_token_loc + 1:-1]
                        qa['eval_a_uns'] = ans_scores.sum().item() if ans_scores.numel() > 0 else 0.0
                        qa['eval_a_ns'] = qa['eval_a_uns'] / ans_scores.numel() if ans_scores.numel() > 0 else 0.0
                        qa['eval_pos_scores'] = hypo[0]['positional_scores'].tolist()
                        qa_id += 1
                json.dump(qas_item, out_text_f)
                out_text_f.write('\n')

                source_line = source_f.readline()
                target_line = target_f.readline()
                qas_line = qas_f.readline()
                count += 1

        assert offset + count == end, "worker ended at {}, expected {}".format(
            offset + count,
            end
        )
        # Release GPU memory before the worker exits.
        del bart
        torch.cuda.empty_cache()
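

A minimal sketch of how this worker might be fanned out across GPUs, one contiguous shard of the input per device. The launcher name run_qa_eval_all_gpus, the per-shard output naming, and the shard arithmetic are assumptions for illustration, not taken from the source file:

import math
import multiprocessing as mp


def run_qa_eval_all_gpus(num_lines, num_gpus, *, input_source_file, input_target_file,
                         input_qas_file, out_prefix, checkpoint_dir, ckp_file, bin_dir, args):
    # "spawn" avoids inheriting CUDA state into child processes.
    ctx = mp.get_context("spawn")
    shard = math.ceil(num_lines / num_gpus)
    procs = []
    for job_idx in range(num_gpus):
        offset = job_idx * shard
        end = min(offset + shard, num_lines)
        if offset >= end:
            break
        # Each worker scores lines [offset, end) and writes its own output file
        # (hypothetical naming); the caller can concatenate the parts afterwards.
        p = ctx.Process(
            target=_run_qa_eval_process_local,
            args=(job_idx,),
            kwargs=dict(input_source_file=input_source_file,
                        input_target_file=input_target_file,
                        input_qas_file=input_qas_file,
                        out_text_file="{}.{}".format(out_prefix, job_idx),
                        offset=offset, end=end,
                        checkpoint_dir=checkpoint_dir, ckp_file=ckp_file,
                        bin_dir=bin_dir, args=args),
        )
        p.start()
        procs.append(p)
    for p in procs:
        p.join()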