in preprocess/sm_inference_asum.py [0:0]
# Assumed imports for this excerpt (the file is fairseq-based; the exact
# import paths in the original module may differ):
import json
from types import SimpleNamespace

import torch

from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.models.bart import BARTModel


def _run_qa_eval_gen_process_local(job_idx, *, input_source_file, input_target_file, input_qas_file, out_text_file,
                                   offset, end, checkpoint_dir, ckp_file, bin_dir, args):
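    """Single-GPU worker: answer generated QA pairs against their source documents.

    Reads lines [offset, end) of the source/target/QAs files, generates an
    answer for every question with the fine-tuned BART checkpoint, and writes
    the updated QA records to out_text_file, one JSON line per input line.
    """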
    bart = BARTModel.from_pretrained(
        checkpoint_dir,
        checkpoint_file=ckp_file,
        data_name_or_path=bin_dir
    )
    torch.cuda.set_device(torch.device("cuda:{}".format(job_idx)))
    bart.cuda()
    bart.eval()
    bart.half()
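
    # Yield aligned slices of (source tokens, source lengths, question tokens,
    # question lengths); each slice becomes one generation batch of size bsz.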
    def batch_for_scorer(source_tokens_list, num_source_token_list, target_tokens_list, num_target_token_list, bsz):
        length = len(source_tokens_list)
        s = 0
        while s < length:
            e = s + bsz
            yield source_tokens_list[s:e], num_source_token_list[s:e], \
                target_tokens_list[s:e], num_target_token_list[s:e]
            s = e
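
    # Token id 50259 decodes to the literal string ' strutConnector'; it is
    # appended to each question as the question/answer separator, and the
    # decoded output is split on that string below.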
    special_token = 50259
    count = 0
    bsz = args.bsz
    # Build the generator once, rather than once per batch: greedy decoding
    # (beam=1) with generated answers capped at 50 tokens.
    gen_args = SimpleNamespace(
        beam=1,
        max_len_b=50,
    )
    generator = bart.task.build_generator(gen_args)
    print("Local worker is processing {}-{}".format(offset, end))
    with torch.no_grad():
        with open(input_source_file, 'r') as source_f, \
                open(input_qas_file, 'r') as qas_f, \
                open(input_target_file, 'r') as target_f, \
                open(out_text_file, 'w') as out_text_f:
            for _ in range(offset):
                source_f.readline()
                target_f.readline()
                qas_f.readline()
            source_line = source_f.readline()
            target_line = target_f.readline()
            qas_line = qas_f.readline()
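            # Process one (source, target, QAs) line triple per iteration.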
            while source_line:
                if offset + count >= end:
                    break
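                # Encode the source (optionally prefixed with the target
                # summary) and truncate to the encoder limit, keeping the
                # final token in the last position.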
                max_source_tokens = 1024
                if args.prepend_target:
                    src_tokens = bart.encode(target_line.strip() + ' ' + source_line.strip(), no_bos=True,
                                             input_is_bpe=False)
                else:
                    src_tokens = bart.encode(source_line.strip(), no_bos=True, input_is_bpe=False)
                if len(src_tokens) > max_source_tokens:
                    src_tokens[max_source_tokens - 1] = src_tokens[-1]
                    src_tokens = src_tokens[:max_source_tokens]
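                # Parse this line's QA records and encode every question,
                # swapping its last token for the separator so generation
                # continues with the answer.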
                qas_item = json.loads(qas_line.strip())
                q_tensors = []
                for hypo_qas in qas_item:
                    for qa in hypo_qas['qas']:
                        q_tensor = bart.encode(qa['q'], no_bos=True, input_is_bpe=False)
                        q_tensor[-1] = special_token
                        q_tensors.append(q_tensor)
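                # Pair the same source with every question and generate
                # answers batch by batch.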
                num_src_tokens = src_tokens.numel()
                src_tokens_list = [src_tokens for _ in range(len(q_tensors))]
                num_src_token_list = [num_src_tokens for _ in range(len(q_tensors))]
                hypos = []
                for s_list, num_s_list, t_list, num_t_list in batch_for_scorer(
                        src_tokens_list, num_src_token_list, q_tensors,
                        [x.numel() for x in q_tensors], bsz):
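                    # Defensive: the dataset expects lists even when a batch
                    # degenerates to a single example.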
                    if not isinstance(s_list, list):
                        s_list = [s_list]
                    if not isinstance(num_s_list, list):
                        num_s_list = [num_s_list]
                    if not isinstance(t_list, list):
                        t_list = [t_list]
                    if not isinstance(num_t_list, list):
                        num_t_list = [num_t_list]
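                    # Wrap the batch in a throwaway LanguagePairDataset so the
                    # collater can pad and batch the source/question pairs.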
                    dataset = LanguagePairDataset(s_list, num_s_list,
                                                  bart.task.source_dictionary,
                                                  t_list, num_t_list,
                                                  bart.task.target_dictionary,
                                                  shuffle=False,
                                                  input_feeding=False)
                    # Collate the whole mini-dataset into one sample and move
                    # it to this worker's GPU.
                    sample = dataset.collater([dataset[i] for i in range(len(dataset))])
                    sample = utils.apply_to_sample(lambda tensor: tensor.cuda(), sample)
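                    # Force-decode the question via prefix_tokens so the model
                    # generates only the answer after the separator.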
                    translations = bart.task.inference_step(
                        generator,
                        [bart.model],
                        sample,
                        prefix_tokens=sample['target']
                    )
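                    # The generator may reorder a batch internally; sort the
                    # outputs back to input order by sample id.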
                    translations = [v for _, v in sorted(zip(sample['id'].tolist(), translations))]
                    hypos += translations
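                # Decode each top hypothesis and split question from answer on
                # the separator string; on mismatch, record an empty answer.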
                qa_id = 0
                for hypo_qas in qas_item:
                    for qa in hypo_qas['qas']:
                        hypo = hypos[qa_id][0]
                        decoded_qa = bart.decode(hypo['tokens'])
                        q_a_split = decoded_qa.split(' strutConnector')
                        if len(q_a_split) == 2 and q_a_split[0] == qa['q']:
                            qa['eval_ans'] = q_a_split[1]
                        else:
                            print('Error in decoded QA: {} | {}'.format(q_a_split, qa['q']))
                            qa['eval_ans'] = ''
                        qa_id += 1
                json.dump(qas_item, out_text_f)
                out_text_f.write('\n')
                source_line = source_f.readline()
                target_line = target_f.readline()
                qas_line = qas_f.readline()
                count += 1
                if count % 100 == 0:
                    print("Generated {} lines from worker {}".format(count, job_idx))
            assert offset + count == end, "Worker ended at {}, should have ended at {}".format(
                offset + count,
                end
            )
    # Release the model and cached GPU memory before this worker exits.
    del bart
    torch.cuda.empty_cache()