in preprocess/sm_inference_asum.py [0:0]
import json

import torch
from fairseq.models.bart import BARTModel


def _run_qa_gen_process_local_batch_lines(job_idx, *, input_source_file, out_text_file,
                                          offset, end, checkpoint_dir, ckp_file, bin_dir, args):
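    """One GPU worker: load a BART checkpoint, read lines [offset, end) of a
    JSONL input file (each line holds a single entry under 'summaries'), run
    batched generation over those summaries, and write one JSON record per
    input line to out_text_file."""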
bart = BARTModel.from_pretrained(
checkpoint_dir,
checkpoint_file=ckp_file,
data_name_or_path=bin_dir
)
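    # Pin this worker to its own GPU and put the model in fp16 eval mode.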
torch.cuda.set_device(torch.device("cuda:{}".format(job_idx)))
bart.cuda()
bart.eval()
bart.half()
    # The first input line is consumed before the loop below, so count starts at 1.
    count = 1
    bsz = args.bsz
print("Local worker is processing {}-{}".format(offset, end))
with torch.no_grad():
with open(input_source_file, 'r') as source_f, \
open(out_text_file, 'w') as out_text_f:
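            # Skip the lines that belong to earlier workers' shards.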
for _ in range(offset):
source_f.readline()
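            # Seed the first batch with the first line of this worker's shard.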
source_line = source_f.readline()
source_item = json.loads(source_line.strip())
assert len(source_item['summaries']) == 1
slines = [source_item['summaries'][0].strip()]
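            # Main loop: accumulate up to bsz summaries per batch, then run one
            # batched generation call and write the results as JSON lines.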
while source_line:
if offset + count >= end:
break
if count % bsz == 0:
hypotheses_batch, score_batch, unnormalized_score_batch, pos_score_batch, tokens_batch = \
_sample_wrapper(
bart,
sentences=slines,
beam=args.beam,
lenpen=1.0,
max_len_b=args.max_len,
min_len=args.min_len,
sampling=args.sampling,
sampling_topk=args.sampling_topk,
sampling_topp=args.sampling_topp,
return_all=args.return_all,
input_is_bpe=False,
return_token_scores=args.return_token_scores,
diverse_beam_groups=args.diverse_beam_groups,
diverse_beam_strength=args.diverse_beam_strength,
)
assert len(hypotheses_batch) == len(score_batch) == len(unnormalized_score_batch), \
"lens not equal: {} and {} and {}".format(
len(hypotheses_batch), len(score_batch), len(unnormalized_score_batch)
)
assert len(hypotheses_batch) == len(slines), "slines={}, generated_score length={}".format(
slines, len(hypotheses_batch)
)
if args.return_token_scores:
                        for t, s, unnormalized_s, pos_s, toks, sline in zip(
                                hypotheses_batch, score_batch, unnormalized_score_batch,
                                pos_score_batch, tokens_batch, slines):
                            qa_item = [{
                                'context': sline,
                                'qa': t if isinstance(t, list) else [t],
                                'norm_scores': s if isinstance(s, list) else [s],
                                'unnorm_scores': unnormalized_s if isinstance(unnormalized_s, list)
                                                 else [unnormalized_s],
                                'pos_scores': [tmp.tolist() for tmp in pos_s]
                                              if args.return_all and args.beam > 1 else [pos_s.tolist()],
                                'toks': [tmp.tolist() for tmp in toks]
                                        if args.return_all and args.beam > 1 else [toks.tolist()],
                            }]
                            json.dump(qa_item, out_text_f)
                            out_text_f.write('\n')
else:
                        for t, s, unnormalized_s, sline in zip(
                                hypotheses_batch, score_batch, unnormalized_score_batch, slines):
                            qa_item = [{
                                'context': sline,
                                'qa': t if isinstance(t, list) else [t],
                                'norm_scores': s if isinstance(s, list) else [s],
                                'unnorm_scores': unnormalized_s if isinstance(unnormalized_s, list)
                                                 else [unnormalized_s],
                            }]
                            json.dump(qa_item, out_text_f)
                            out_text_f.write('\n')
out_text_f.flush()
slines = []
                source_line = source_f.readline()
                if not source_line:
                    # Hit EOF before reaching `end`; the final assert reports the shortfall.
                    break
                source_item = json.loads(source_line.strip())
                assert len(source_item['summaries']) == 1
                slines.append(source_item['summaries'][0].strip())
                count += 1
            # Flush the final partial batch left over when the loop exits.
            if slines:
hypotheses_batch, score_batch, unnormalized_score_batch, pos_score_batch, tokens_batch = \
_sample_wrapper(
bart,
sentences=slines,
beam=args.beam,
lenpen=1.0,
max_len_b=args.max_len,
min_len=args.min_len,
sampling=args.sampling,
sampling_topk=args.sampling_topk,
sampling_topp=args.sampling_topp,
return_all=args.return_all,
input_is_bpe=False,
return_token_scores=args.return_token_scores,
diverse_beam_groups=args.diverse_beam_groups,
diverse_beam_strength=args.diverse_beam_strength,
)
assert len(hypotheses_batch) == len(score_batch) == len(unnormalized_score_batch), \
"lens not equal: {} and {} and {}".format(
len(hypotheses_batch), len(score_batch), len(unnormalized_score_batch)
)
assert len(hypotheses_batch) == len(slines), "slines={}, generated_score length={}".format(
slines, len(hypotheses_batch)
)
if args.return_token_scores:
                for t, s, unnormalized_s, pos_s, toks, sline in zip(
                        hypotheses_batch, score_batch, unnormalized_score_batch,
                        pos_score_batch, tokens_batch, slines):
                    qa_item = [{
                        'context': sline,
                        'qa': t if isinstance(t, list) else [t],
                        'norm_scores': s if isinstance(s, list) else [s],
                        'unnorm_scores': unnormalized_s if isinstance(unnormalized_s, list)
                                         else [unnormalized_s],
                        'pos_scores': [tmp.tolist() for tmp in pos_s]
                                      if args.return_all and args.beam > 1 else [pos_s.tolist()],
                        'toks': [tmp.tolist() for tmp in toks]
                                if args.return_all and args.beam > 1 else [toks.tolist()],
                    }]
                    json.dump(qa_item, out_text_f)
                    out_text_f.write('\n')
else:
                for t, s, unnormalized_s, sline in zip(
                        hypotheses_batch, score_batch, unnormalized_score_batch, slines):
                    qa_item = [{
                        'context': sline,
                        'qa': t if isinstance(t, list) else [t],
                        'norm_scores': s if isinstance(s, list) else [s],
                        'unnorm_scores': unnormalized_s if isinstance(unnormalized_s, list)
                                         else [unnormalized_s],
                    }]
                    json.dump(qa_item, out_text_f)
                    out_text_f.write('\n')
out_text_f.flush()
assert offset + count == end, "!worker ended at {}, should have been {}".format(
offset + count,
end
)
    # Release the model and return cached GPU memory before the worker exits.
    del bart
    torch.cuda.empty_cache()
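
# A minimal sketch of how a worker like this might be launched, one process per
# GPU over contiguous shards of the input file. This helper and its argument
# names are illustrative assumptions, not part of the original file; only
# _run_qa_gen_process_local_batch_lines and its signature come from above.
def _launch_qa_gen_workers_sketch(input_source_file, out_prefix, checkpoint_dir,
                                  ckp_file, bin_dir, args):
    import torch.multiprocessing as mp

    # Count input lines, then split them evenly across the visible GPUs.
    with open(input_source_file) as f:
        num_lines = sum(1 for _ in f)
    n_gpus = torch.cuda.device_count()
    shard = (num_lines + n_gpus - 1) // n_gpus

    # CUDA state cannot be forked; child processes must be spawned.
    ctx = mp.get_context("spawn")
    procs = []
    for job_idx in range(n_gpus):
        offset = job_idx * shard
        end = min(num_lines, offset + shard)
        if offset >= end:
            break  # fewer shards than GPUs
        p = ctx.Process(
            target=_run_qa_gen_process_local_batch_lines,
            args=(job_idx,),
            kwargs=dict(input_source_file=input_source_file,
                        out_text_file="{}.{}".format(out_prefix, job_idx),
                        offset=offset, end=end, checkpoint_dir=checkpoint_dir,
                        ckp_file=ckp_file, bin_dir=bin_dir, args=args),
        )
        p.start()
        procs.append(p)
    for p in procs:
        p.join()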