in src/nli/inference_debug.py [0:0]
def inference(model_class_name, model_checkpoint_path, max_length, premise, hypothesis, cpu=True):
    """Run single-pair NLI inference with a fine-tuned sequence-classification model.

    Args:
        model_class_name: key into MODEL_CLASSES selecting tokenizer/model classes.
        model_checkpoint_path: path to the fine-tuned state_dict to load.
        max_length: max tokenized sequence length passed to NLITransform.
        premise: premise sentence.
        hypothesis: hypothesis sentence.
        cpu: if True run on CPU (global_rank = -1); otherwise use GPU 0.

    Returns:
        The first (and only) prediction dict from eval_model's output list
        for the single {premise, hypothesis} example.
    """
    # Use a bare Namespace rather than ArgumentParser().parse_args():
    # parse_args() reads sys.argv and would abort the host process if it
    # was launched with its own CLI arguments.
    args = argparse.Namespace()
    args.global_rank = -1 if cpu else 0
    args.model_class_name = model_class_name

    # NLI task => 3 labels; change for other classification tasks.
    num_labels = 3

    model_class_item = MODEL_CLASSES[model_class_name]
    model_name = model_class_item['model_name']
    do_lower_case = model_class_item.get('do_lower_case', False)

    tokenizer = model_class_item['tokenizer'].from_pretrained(
        model_name,
        cache_dir=str(config.PRO_ROOT / "trans_cache"),
        do_lower_case=do_lower_case)

    model = model_class_item['sequence_classification'].from_pretrained(
        model_name,
        cache_dir=str(config.PRO_ROOT / "trans_cache"),
        num_labels=num_labels)

    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
    # host; on the GPU path keep torch's default device mapping.
    model.load_state_dict(torch.load(
        model_checkpoint_path,
        map_location='cpu' if cpu else None))

    padding_token_value = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
    padding_segement_value = model_class_item["padding_segement_value"]
    padding_att_value = model_class_item["padding_att_value"]
    left_pad = model_class_item.get('left_pad', False)

    batch_size_per_gpu_eval = 16

    # Single-example "dataset"; label 'h' marks the gold label as hidden.
    eval_data_list = [{
        'uid': str(uuid.uuid4()),
        'premise': premise,
        'hypothesis': hypothesis,
        'label': 'h'  # hidden
    }]

    batching_schema = {
        'uid': RawFlintField(),
        'y': LabelFlintField(),
        'input_ids': ArrayIndexFlintField(pad_idx=padding_token_value, left_pad=left_pad),
        'token_type_ids': ArrayIndexFlintField(pad_idx=padding_segement_value, left_pad=left_pad),
        'attention_mask': ArrayIndexFlintField(pad_idx=padding_att_value, left_pad=left_pad),
    }

    data_transformer = NLITransform(model_name, tokenizer, max_length)
    d_dataset, d_sampler, d_dataloader = build_eval_dataset_loader_and_sampler(
        eval_data_list, data_transformer, batching_schema, batch_size_per_gpu_eval)

    if not cpu:
        torch.cuda.set_device(0)
        model.cuda(0)

    pred_output_list = eval_model(model, d_dataloader, args.global_rank, args)
    return pred_output_list[0]