preprocess/sm_inference_asum.py [557:609]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                             offset, end, checkpoint_dir, ckp_file, bin_dir, args):
    bart = BARTModel.from_pretrained(
        checkpoint_dir,
        checkpoint_file=ckp_file,
        data_name_or_path=bin_dir
    )
    torch.cuda.set_device(torch.device("cuda:{}".format(job_idx)))
    bart.cuda()
    bart.eval()
    bart.half()

    def batch_for_scorer(source_tokens_list, num_source_token_list, target_tokens_list, num_target_token_list, bsz):
        """Yield batches of at most `bsz` items from four parallel lists.

        Each yielded value is a 4-tuple of aligned slices:
        (source tokens, source token counts, target tokens, target token counts).
        The final batch may be shorter than `bsz`.
        """
        total = len(source_tokens_list)
        for start in range(0, total, bsz):
            stop = start + bsz
            yield (source_tokens_list[start:stop],
                   num_source_token_list[start:stop],
                   target_tokens_list[start:stop],
                   num_target_token_list[start:stop])

    special_token = 50259

    count = 0
    # bsz = 32
    bsz = args.bsz
    print("Local worker is processing {}-{}".format(offset, end))
    with torch.no_grad():
        with open(input_source_file, 'r') as source_f, \
                open(input_qas_file, 'r') as qas_f, \
                open(input_target_file, 'r') as target_f, \
                open(out_text_file, 'w') as out_text_f:
            for _ in range(offset):
                source_f.readline()
                target_f.readline()
                qas_f.readline()
            source_line = source_f.readline()
            target_line = target_f.readline()
            qas_line = qas_f.readline()
            while source_line:
                if offset + count >= end:
                    break

                max_source_tokens = 1024
                if args.prepend_target:
                    src_tokens = bart.encode(target_line.strip() + ' ' + source_line.strip(), no_bos=True,
                                          input_is_bpe=False)
                else:
                    src_tokens = bart.encode(source_line.strip(), no_bos=True, input_is_bpe=False)
                if len(src_tokens) > max_source_tokens:
                    src_tokens[max_source_tokens - 1] = src_tokens[-1]
                src_tokens = src_tokens if len(src_tokens) <= max_source_tokens else src_tokens[:max_source_tokens]

                qas_item = json.loads(qas_line.strip())
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



preprocess/sm_inference_asum.py [687:740]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                             offset, end, checkpoint_dir, ckp_file, bin_dir, args):
    bart = BARTModel.from_pretrained(
        checkpoint_dir,
        checkpoint_file=ckp_file,
        data_name_or_path=bin_dir
    )
    torch.cuda.set_device(torch.device("cuda:{}".format(job_idx)))
    bart.cuda()
    bart.eval()
    bart.half()

    def batch_for_scorer(source_tokens_list, num_source_token_list, target_tokens_list, num_target_token_list, bsz):
        """Yield batches of at most `bsz` items from four parallel lists.

        Each yielded value is a 4-tuple of aligned slices:
        (source tokens, source token counts, target tokens, target token counts).
        The final batch may be shorter than `bsz`.
        """
        total = len(source_tokens_list)
        for start in range(0, total, bsz):
            stop = start + bsz
            yield (source_tokens_list[start:stop],
                   num_source_token_list[start:stop],
                   target_tokens_list[start:stop],
                   num_target_token_list[start:stop])

    special_token = 50259

    count = 0
    # bsz = 32
    bsz = args.bsz
    print("Local worker is processing {}-{}".format(offset, end))
    with torch.no_grad():
        with open(input_source_file, 'r') as source_f, \
                open(input_qas_file, 'r') as qas_f, \
                open(input_target_file, 'r') as target_f, \
                open(out_text_file, 'w') as out_text_f:
            for _ in range(offset):
                source_f.readline()
                target_f.readline()
                qas_f.readline()
            source_line = source_f.readline()
            target_line = target_f.readline()
            qas_line = qas_f.readline()
            while source_line:
                if offset + count >= end:
                    break

                max_source_tokens = 1024
                if args.prepend_target:
                    src_tokens = bart.encode(target_line.strip() + ' ' + source_line.strip(), no_bos=True,
                                              input_is_bpe=False)
                else:
                    src_tokens = bart.encode(source_line.strip(), no_bos=True, input_is_bpe=False)

                if len(src_tokens) > max_source_tokens:
                    src_tokens[max_source_tokens - 1] = src_tokens[-1]
                src_tokens = src_tokens if len(src_tokens) <= max_source_tokens else src_tokens[:max_source_tokens]

                qas_item = json.loads(qas_line.strip())
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



